]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/tokenizer/token_analysis/generic.py
add type annotations for token analysis
[nominatim.git] / nominatim / tokenizer / token_analysis / generic.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Generic processor for names that creates abbreviation variants.
9 """
10 from typing import Mapping, Dict, Any, Iterable, Iterator, Optional, List, cast
11 import itertools
12
13 import datrie
14
15 from nominatim.errors import UsageError
16 from nominatim.tokenizer.token_analysis.config_variants import get_variant_config
17 from nominatim.tokenizer.token_analysis.generic_mutation import MutationVariantGenerator
18
19 ### Configuration section
20
21 def configure(rules: Mapping[str, Any], normalization_rules: str) -> Dict[str, Any]:
22     """ Extract and preprocess the configuration for this module.
23     """
24     config: Dict[str, Any] = {}
25
26     config['replacements'], config['chars'] = get_variant_config(rules.get('variants'),
27                                                                  normalization_rules)
28     config['variant_only'] = rules.get('mode', '') == 'variant-only'
29
30     # parse mutation rules
31     config['mutations'] = []
32     for rule in rules.get('mutations', []):
33         if 'pattern' not in rule:
34             raise UsageError("Missing field 'pattern' in mutation configuration.")
35         if not isinstance(rule['pattern'], str):
36             raise UsageError("Field 'pattern' in mutation configuration "
37                              "must be a simple text field.")
38         if 'replacements' not in rule:
39             raise UsageError("Missing field 'replacements' in mutation configuration.")
40         if not isinstance(rule['replacements'], list):
41             raise UsageError("Field 'replacements' in mutation configuration "
42                              "must be a list of texts.")
43
44         config['mutations'].append((rule['pattern'], rule['replacements']))
45
46     return config
47
48
49 ### Analysis section
50
51 def create(normalizer: Any, transliterator: Any,
52            config: Mapping[str, Any]) -> 'GenericTokenAnalysis':
53     """ Create a new token analysis instance for this module.
54     """
55     return GenericTokenAnalysis(normalizer, transliterator, config)
56
57
58 class GenericTokenAnalysis:
59     """ Collects the different transformation rules for normalisation of names
60         and provides the functions to apply the transformations.
61     """
62
63     def __init__(self, norm: Any, to_ascii: Any, config: Mapping[str, Any]) -> None:
64         self.norm = norm
65         self.to_ascii = to_ascii
66         self.variant_only = config['variant_only']
67
68         # Set up datrie
69         if config['replacements']:
70             self.replacements = datrie.Trie(config['chars'])
71             for src, repllist in config['replacements']:
72                 self.replacements[src] = repllist
73         else:
74             self.replacements = None
75
76         # set up mutation rules
77         self.mutations = [MutationVariantGenerator(*cfg) for cfg in config['mutations']]
78
79
80     def normalize(self, name: str) -> str:
81         """ Return the normalized form of the name. This is the standard form
82             from which possible variants for the name can be derived.
83         """
84         return cast(str, self.norm.transliterate(name)).strip()
85
86
87     def get_variants_ascii(self, norm_name: str) -> List[str]:
88         """ Compute the spelling variants for the given normalized name
89             and transliterate the result.
90         """
91         variants = self._generate_word_variants(norm_name)
92
93         for mutation in self.mutations:
94             variants = mutation.generate(variants)
95
96         return [name for name in self._transliterate_unique_list(norm_name, variants) if name]
97
98
99     def _transliterate_unique_list(self, norm_name: str,
100                                    iterable: Iterable[str]) -> Iterator[Optional[str]]:
101         seen = set()
102         if self.variant_only:
103             seen.add(norm_name)
104
105         for variant in map(str.strip, iterable):
106             if variant not in seen:
107                 seen.add(variant)
108                 yield self.to_ascii.transliterate(variant).strip()
109
110
111     def _generate_word_variants(self, norm_name: str) -> Iterable[str]:
112         baseform = '^ ' + norm_name + ' ^'
113         baselen = len(baseform)
114         partials = ['']
115
116         startpos = 0
117         if self.replacements is not None:
118             pos = 0
119             force_space = False
120             while pos < baselen:
121                 full, repl = self.replacements.longest_prefix_item(baseform[pos:],
122                                                                    (None, None))
123                 if full is not None:
124                     done = baseform[startpos:pos]
125                     partials = [v + done + r
126                                 for v, r in itertools.product(partials, repl)
127                                 if not force_space or r.startswith(' ')]
128                     if len(partials) > 128:
129                         # If too many variants are produced, they are unlikely
130                         # to be helpful. Only use the original term.
131                         startpos = 0
132                         break
133                     startpos = pos + len(full)
134                     if full[-1] == ' ':
135                         startpos -= 1
136                         force_space = True
137                     pos = startpos
138                 else:
139                     pos += 1
140                     force_space = False
141
142         # No variants detected? Fast return.
143         if startpos == 0:
144             return (norm_name, )
145
146         if startpos < baselen:
147             return (part[1:] + baseform[startpos:-1] for part in partials)
148
149         return (part[1:-1] for part in partials)