]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/tokenizer/token_analysis/generic.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_db / tokenizer / token_analysis / generic.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Generic processor for names that creates abbreviation variants.
9 """
10 from typing import Mapping, Dict, Any, Iterable, Iterator, Optional, List, cast
11 import itertools
12
13 import datrie
14
15 from ...errors import UsageError
16 from ...data.place_name import PlaceName
17 from .config_variants import get_variant_config
18 from .generic_mutation import MutationVariantGenerator
19
20 # Configuration section
21
22
23 def configure(rules: Mapping[str, Any], normalizer: Any, _: Any) -> Dict[str, Any]:
24     """ Extract and preprocess the configuration for this module.
25     """
26     config: Dict[str, Any] = {}
27
28     config['replacements'], config['chars'] = get_variant_config(rules.get('variants'),
29                                                                  normalizer)
30     config['variant_only'] = rules.get('mode', '') == 'variant-only'
31
32     # parse mutation rules
33     config['mutations'] = []
34     for rule in rules.get('mutations', []):
35         if 'pattern' not in rule:
36             raise UsageError("Missing field 'pattern' in mutation configuration.")
37         if not isinstance(rule['pattern'], str):
38             raise UsageError("Field 'pattern' in mutation configuration "
39                              "must be a simple text field.")
40         if 'replacements' not in rule:
41             raise UsageError("Missing field 'replacements' in mutation configuration.")
42         if not isinstance(rule['replacements'], list):
43             raise UsageError("Field 'replacements' in mutation configuration "
44                              "must be a list of texts.")
45
46         config['mutations'].append((rule['pattern'], rule['replacements']))
47
48     return config
49
50
51 # Analysis section
52
53 def create(normalizer: Any, transliterator: Any,
54            config: Mapping[str, Any]) -> 'GenericTokenAnalysis':
55     """ Create a new token analysis instance for this module.
56     """
57     return GenericTokenAnalysis(normalizer, transliterator, config)
58
59
60 class GenericTokenAnalysis:
61     """ Collects the different transformation rules for normalisation of names
62         and provides the functions to apply the transformations.
63     """
64
65     def __init__(self, norm: Any, to_ascii: Any, config: Mapping[str, Any]) -> None:
66         self.norm = norm
67         self.to_ascii = to_ascii
68         self.variant_only = config['variant_only']
69
70         # Set up datrie
71         if config['replacements']:
72             self.replacements = datrie.Trie(config['chars'])
73             for src, repllist in config['replacements']:
74                 self.replacements[src] = repllist
75         else:
76             self.replacements = None
77
78         # set up mutation rules
79         self.mutations = [MutationVariantGenerator(*cfg) for cfg in config['mutations']]
80
81     def get_canonical_id(self, name: PlaceName) -> str:
82         """ Return the normalized form of the name. This is the standard form
83             from which possible variants for the name can be derived.
84         """
85         return cast(str, self.norm.transliterate(name.name)).strip()
86
87     def compute_variants(self, norm_name: str) -> List[str]:
88         """ Compute the spelling variants for the given normalized name
89             and transliterate the result.
90         """
91         variants = self._generate_word_variants(norm_name)
92
93         for mutation in self.mutations:
94             variants = mutation.generate(variants)
95
96         return [name for name in self._transliterate_unique_list(norm_name, variants) if name]
97
98     def _transliterate_unique_list(self, norm_name: str,
99                                    iterable: Iterable[str]) -> Iterator[Optional[str]]:
100         seen = set()
101         if self.variant_only:
102             seen.add(norm_name)
103
104         for variant in map(str.strip, iterable):
105             if variant not in seen:
106                 seen.add(variant)
107                 yield self.to_ascii.transliterate(variant).strip()
108
109     def _generate_word_variants(self, norm_name: str) -> Iterable[str]:
110         baseform = '^ ' + norm_name + ' ^'
111         baselen = len(baseform)
112         partials = ['']
113
114         startpos = 0
115         if self.replacements is not None:
116             pos = 0
117             force_space = False
118             while pos < baselen:
119                 full, repl = self.replacements.longest_prefix_item(baseform[pos:],
120                                                                    (None, None))
121                 if full is not None:
122                     done = baseform[startpos:pos]
123                     partials = [v + done + r
124                                 for v, r in itertools.product(partials, repl)
125                                 if not force_space or r.startswith(' ')]
126                     if len(partials) > 128:
127                         # If too many variants are produced, they are unlikely
128                         # to be helpful. Only use the original term.
129                         startpos = 0
130                         break
131                     startpos = pos + len(full)
132                     if full[-1] == ' ':
133                         startpos -= 1
134                         force_space = True
135                     pos = startpos
136                 else:
137                     pos += 1
138                     force_space = False
139
140         # No variants detected? Fast return.
141         if startpos == 0:
142             return (norm_name, )
143
144         if startpos < baselen:
145             return (part[1:] + baseform[startpos:-1] for part in partials)
146
147         return (part[1:-1] for part in partials)