]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/tokenizer/token_analysis/generic.py
Added missing return types to functions
[nominatim.git] / nominatim / tokenizer / token_analysis / generic.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Generic processor for names that creates abbreviation variants.
9 """
10 from typing import Mapping, Dict, Any, Iterable, Iterator, Optional, List, cast
11 import itertools
12
13 import datrie
14
15 from nominatim.errors import UsageError
16 from nominatim.data.place_name import PlaceName
17 from nominatim.tokenizer.token_analysis.config_variants import get_variant_config
18 from nominatim.tokenizer.token_analysis.generic_mutation import MutationVariantGenerator
19
20 ### Configuration section
21
22 def configure(rules: Mapping[str, Any], normalizer: Any, _: Any) -> Dict[str, Any]:
23     """ Extract and preprocess the configuration for this module.
24     """
25     config: Dict[str, Any] = {}
26
27     config['replacements'], config['chars'] = get_variant_config(rules.get('variants'),
28                                                                  normalizer)
29     config['variant_only'] = rules.get('mode', '') == 'variant-only'
30
31     # parse mutation rules
32     config['mutations'] = []
33     for rule in rules.get('mutations', []):
34         if 'pattern' not in rule:
35             raise UsageError("Missing field 'pattern' in mutation configuration.")
36         if not isinstance(rule['pattern'], str):
37             raise UsageError("Field 'pattern' in mutation configuration "
38                              "must be a simple text field.")
39         if 'replacements' not in rule:
40             raise UsageError("Missing field 'replacements' in mutation configuration.")
41         if not isinstance(rule['replacements'], list):
42             raise UsageError("Field 'replacements' in mutation configuration "
43                              "must be a list of texts.")
44
45         config['mutations'].append((rule['pattern'], rule['replacements']))
46
47     return config
48
49
50 ### Analysis section
51
52 def create(normalizer: Any, transliterator: Any,
53            config: Mapping[str, Any]) -> 'GenericTokenAnalysis':
54     """ Create a new token analysis instance for this module.
55     """
56     return GenericTokenAnalysis(normalizer, transliterator, config)
57
58
59 class GenericTokenAnalysis:
60     """ Collects the different transformation rules for normalisation of names
61         and provides the functions to apply the transformations.
62     """
63
64     def __init__(self, norm: Any, to_ascii: Any, config: Mapping[str, Any]) -> None:
65         self.norm = norm
66         self.to_ascii = to_ascii
67         self.variant_only = config['variant_only']
68
69         # Set up datrie
70         if config['replacements']:
71             self.replacements = datrie.Trie(config['chars'])
72             for src, repllist in config['replacements']:
73                 self.replacements[src] = repllist
74         else:
75             self.replacements = None
76
77         # set up mutation rules
78         self.mutations = [MutationVariantGenerator(*cfg) for cfg in config['mutations']]
79
80
81     def get_canonical_id(self, name: PlaceName) -> str:
82         """ Return the normalized form of the name. This is the standard form
83             from which possible variants for the name can be derived.
84         """
85         return cast(str, self.norm.transliterate(name.name)).strip()
86
87
88     def compute_variants(self, norm_name: str) -> List[str]:
89         """ Compute the spelling variants for the given normalized name
90             and transliterate the result.
91         """
92         variants = self._generate_word_variants(norm_name)
93
94         for mutation in self.mutations:
95             variants = mutation.generate(variants)
96
97         return [name for name in self._transliterate_unique_list(norm_name, variants) if name]
98
99
100     def _transliterate_unique_list(self, norm_name: str,
101                                    iterable: Iterable[str]) -> Iterator[Optional[str]]:
102         seen = set()
103         if self.variant_only:
104             seen.add(norm_name)
105
106         for variant in map(str.strip, iterable):
107             if variant not in seen:
108                 seen.add(variant)
109                 yield self.to_ascii.transliterate(variant).strip()
110
111
112     def _generate_word_variants(self, norm_name: str) -> Iterable[str]:
113         baseform = '^ ' + norm_name + ' ^'
114         baselen = len(baseform)
115         partials = ['']
116
117         startpos = 0
118         if self.replacements is not None:
119             pos = 0
120             force_space = False
121             while pos < baselen:
122                 full, repl = self.replacements.longest_prefix_item(baseform[pos:],
123                                                                    (None, None))
124                 if full is not None:
125                     done = baseform[startpos:pos]
126                     partials = [v + done + r
127                                 for v, r in itertools.product(partials, repl)
128                                 if not force_space or r.startswith(' ')]
129                     if len(partials) > 128:
130                         # If too many variants are produced, they are unlikely
131                         # to be helpful. Only use the original term.
132                         startpos = 0
133                         break
134                     startpos = pos + len(full)
135                     if full[-1] == ' ':
136                         startpos -= 1
137                         force_space = True
138                     pos = startpos
139                 else:
140                     pos += 1
141                     force_space = False
142
143         # No variants detected? Fast return.
144         if startpos == 0:
145             return (norm_name, )
146
147         if startpos < baselen:
148             return (part[1:] + baseform[startpos:-1] for part in partials)
149
150         return (part[1:-1] for part in partials)