]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/tokenizer/token_analysis/generic.py
add tests for migration
[nominatim.git] / nominatim / tokenizer / token_analysis / generic.py
1 """
2 Generic processor for names that creates abbreviation variants.
3 """
4 from collections import defaultdict, namedtuple
5 import itertools
6 import re
7
8 from icu import Transliterator
9 import datrie
10
11 from nominatim.config import flatten_config_list
12 from nominatim.errors import UsageError
13
14 ### Configuration section
15
16 ICUVariant = namedtuple('ICUVariant', ['source', 'replacement'])
17
18 def configure(rules, normalization_rules):
19     """ Extract and preprocess the configuration for this module.
20     """
21     config = {}
22
23     config['replacements'], config['chars'] = _get_variant_config(rules.get('variants'),
24                                                                   normalization_rules)
25     config['variant_only'] = rules.get('mode', '') == 'variant-only'
26
27     return config
28
29
30 def _get_variant_config(rules, normalization_rules):
31     """ Convert the variant definition from the configuration into
32         replacement sets.
33     """
34     immediate = defaultdict(list)
35     chars = set()
36
37     if rules:
38         vset = set()
39         rules = flatten_config_list(rules, 'variants')
40
41         vmaker = _VariantMaker(normalization_rules)
42
43         for section in rules:
44             for rule in (section.get('words') or []):
45                 vset.update(vmaker.compute(rule))
46
47         # Intermediate reorder by source. Also compute required character set.
48         for variant in vset:
49             if variant.source[-1] == ' ' and variant.replacement[-1] == ' ':
50                 replstr = variant.replacement[:-1]
51             else:
52                 replstr = variant.replacement
53             immediate[variant.source].append(replstr)
54             chars.update(variant.source)
55
56     return list(immediate.items()), ''.join(chars)
57
58
59 class _VariantMaker:
60     """ Generater for all necessary ICUVariants from a single variant rule.
61
62         All text in rules is normalized to make sure the variants match later.
63     """
64
65     def __init__(self, norm_rules):
66         self.norm = Transliterator.createFromRules("rule_loader_normalization",
67                                                    norm_rules)
68
69
70     def compute(self, rule):
71         """ Generator for all ICUVariant tuples from a single variant rule.
72         """
73         parts = re.split(r'(\|)?([=-])>', rule)
74         if len(parts) != 4:
75             raise UsageError("Syntax error in variant rule: " + rule)
76
77         decompose = parts[1] is None
78         src_terms = [self._parse_variant_word(t) for t in parts[0].split(',')]
79         repl_terms = (self.norm.transliterate(t).strip() for t in parts[3].split(','))
80
81         # If the source should be kept, add a 1:1 replacement
82         if parts[2] == '-':
83             for src in src_terms:
84                 if src:
85                     for froms, tos in _create_variants(*src, src[0], decompose):
86                         yield ICUVariant(froms, tos)
87
88         for src, repl in itertools.product(src_terms, repl_terms):
89             if src and repl:
90                 for froms, tos in _create_variants(*src, repl, decompose):
91                     yield ICUVariant(froms, tos)
92
93
94     def _parse_variant_word(self, name):
95         name = name.strip()
96         match = re.fullmatch(r'([~^]?)([^~$^]*)([~$]?)', name)
97         if match is None or (match.group(1) == '~' and match.group(3) == '~'):
98             raise UsageError("Invalid variant word descriptor '{}'".format(name))
99         norm_name = self.norm.transliterate(match.group(2)).strip()
100         if not norm_name:
101             return None
102
103         return norm_name, match.group(1), match.group(3)
104
105
106 _FLAG_MATCH = {'^': '^ ',
107                '$': ' ^',
108                '': ' '}
109
110
111 def _create_variants(src, preflag, postflag, repl, decompose):
112     if preflag == '~':
113         postfix = _FLAG_MATCH[postflag]
114         # suffix decomposition
115         src = src + postfix
116         repl = repl + postfix
117
118         yield src, repl
119         yield ' ' + src, ' ' + repl
120
121         if decompose:
122             yield src, ' ' + repl
123             yield ' ' + src, repl
124     elif postflag == '~':
125         # prefix decomposition
126         prefix = _FLAG_MATCH[preflag]
127         src = prefix + src
128         repl = prefix + repl
129
130         yield src, repl
131         yield src + ' ', repl + ' '
132
133         if decompose:
134             yield src, repl + ' '
135             yield src + ' ', repl
136     else:
137         prefix = _FLAG_MATCH[preflag]
138         postfix = _FLAG_MATCH[postflag]
139
140         yield prefix + src + postfix, prefix + repl + postfix
141
142
143 ### Analysis section
144
145 def create(transliterator, config):
146     """ Create a new token analysis instance for this module.
147     """
148     return GenericTokenAnalysis(transliterator, config)
149
150
151 class GenericTokenAnalysis:
152     """ Collects the different transformation rules for normalisation of names
153         and provides the functions to apply the transformations.
154     """
155
156     def __init__(self, to_ascii, config):
157         self.to_ascii = to_ascii
158         self.variant_only = config['variant_only']
159
160         # Set up datrie
161         if config['replacements']:
162             self.replacements = datrie.Trie(config['chars'])
163             for src, repllist in config['replacements']:
164                 self.replacements[src] = repllist
165         else:
166             self.replacements = None
167
168
169     def get_variants_ascii(self, norm_name):
170         """ Compute the spelling variants for the given normalized name
171             and transliterate the result.
172         """
173         baseform = '^ ' + norm_name + ' ^'
174         partials = ['']
175
176         startpos = 0
177         if self.replacements is not None:
178             pos = 0
179             force_space = False
180             while pos < len(baseform):
181                 full, repl = self.replacements.longest_prefix_item(baseform[pos:],
182                                                                    (None, None))
183                 if full is not None:
184                     done = baseform[startpos:pos]
185                     partials = [v + done + r
186                                 for v, r in itertools.product(partials, repl)
187                                 if not force_space or r.startswith(' ')]
188                     if len(partials) > 128:
189                         # If too many variants are produced, they are unlikely
190                         # to be helpful. Only use the original term.
191                         startpos = 0
192                         break
193                     startpos = pos + len(full)
194                     if full[-1] == ' ':
195                         startpos -= 1
196                         force_space = True
197                     pos = startpos
198                 else:
199                     pos += 1
200                     force_space = False
201
202         # No variants detected? Fast return.
203         if startpos == 0:
204             if self.variant_only:
205                 return []
206
207             trans_name = self.to_ascii.transliterate(norm_name).strip()
208             return [trans_name] if trans_name else []
209
210         return self._compute_result_set(partials, baseform[startpos:],
211                                         norm_name if self.variant_only else '')
212
213
214     def _compute_result_set(self, partials, prefix, exclude):
215         results = set()
216
217         for variant in partials:
218             vname = (variant + prefix)[1:-1].strip()
219             if vname != exclude:
220                 trans_name = self.to_ascii.transliterate(vname).strip()
221                 if trans_name:
222                     results.add(trans_name)
223
224         return list(results)