]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/tokenizer/token_analysis/generic_mutation.py
Merge pull request #3582 from lonvia/switch-to-flake
[nominatim.git] / src / nominatim_db / tokenizer / token_analysis / generic_mutation.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Creator for mutation variants for the generic token analysis.
9 """
10 from typing import Sequence, Iterable, Iterator, Tuple
11 import itertools
12 import logging
13 import re
14
15 from ...errors import UsageError
16
17 LOG = logging.getLogger()
18
19
20 def _zigzag(outer: Iterable[str], inner: Iterable[str]) -> Iterator[str]:
21     return itertools.chain.from_iterable(itertools.zip_longest(outer, inner, fillvalue=''))
22
23
24 class MutationVariantGenerator:
25     """ Generates name variants by applying a regular expression to the name
26         and replacing it with one or more variants. When the regular expression
27         matches more than once, each occurrence is replaced with all replacement
28         patterns.
29     """
30
31     def __init__(self, pattern: str, replacements: Sequence[str]):
32         self.pattern = re.compile(pattern)
33         self.replacements = replacements
34
35         if self.pattern.groups > 0:
36             LOG.fatal("The mutation pattern %s contains a capturing group. "
37                       "This is not allowed.", pattern)
38             raise UsageError("Bad mutation pattern in configuration.")
39
40     def generate(self, names: Iterable[str]) -> Iterator[str]:
41         """ Generator function for the name variants. 'names' is an iterable
42             over a set of names for which the variants are to be generated.
43         """
44         for name in names:
45             parts = self.pattern.split(name)
46             if len(parts) == 1:
47                 yield name
48             else:
49                 for seps in self._fillers(len(parts)):
50                     yield ''.join(_zigzag(parts, seps))
51
52     def _fillers(self, num_parts: int) -> Iterator[Tuple[str, ...]]:
53         """ Returns a generator for strings to join the given number of string
54             parts in all possible combinations.
55         """
56         return itertools.product(self.replacements, repeat=num_parts - 1)