]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
d52614fdaf4863b0aaa2e2721659988d9c03470d
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14 import re
15 from itertools import zip_longest
16
17 from icu import Transliterator
18
19 import sqlalchemy as sa
20
21 from ..errors import UsageError
22 from ..typing import SaRow
23 from ..sql.sqlalchemy_types import Json
24 from ..connection import SearchConnection
25 from ..logging import log
26 from . import query as qmod
27 from ..query_preprocessing.config import QueryConfig
28 from .query_analyzer_factory import AbstractQueryAnalyzer
29
30
31 DB_TO_TOKEN_TYPE = {
32     'W': qmod.TokenType.WORD,
33     'w': qmod.TokenType.PARTIAL,
34     'H': qmod.TokenType.HOUSENUMBER,
35     'P': qmod.TokenType.POSTCODE,
36     'C': qmod.TokenType.COUNTRY
37 }
38
39
40 class QueryPart(NamedTuple):
41     """ Normalized and transliterated form of a single term in the query.
42         When the term came out of a split during the transliteration,
43         the normalized string is the full word before transliteration.
44         The word number keeps track of the word before transliteration
45         and can be used to identify partial transliterated terms.
46     """
47     token: str
48     normalized: str
49     word_number: int
50
51
52 QueryParts = List[QueryPart]
53 WordDict = Dict[str, List[qmod.TokenRange]]
54
55
56 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
57     """ Return all combinations of words in the terms list after the
58         given position.
59     """
60     total = len(terms)
61     for first in range(start, total):
62         word = terms[first].token
63         yield word, qmod.TokenRange(first, first + 1)
64         for last in range(first + 1, min(first + 20, total)):
65             word = ' '.join((word, terms[last].token))
66             yield word, qmod.TokenRange(first, last + 1)
67
68
69 @dataclasses.dataclass
70 class ICUToken(qmod.Token):
71     """ Specialised token for ICU tokenizer.
72     """
73     word_token: str
74     info: Optional[Dict[str, Any]]
75
76     def get_category(self) -> Tuple[str, str]:
77         assert self.info
78         return self.info.get('class', ''), self.info.get('type', '')
79
80     def rematch(self, norm: str) -> None:
81         """ Check how well the token matches the given normalized string
82             and add a penalty, if necessary.
83         """
84         if not self.lookup_word:
85             return
86
87         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
88         distance = 0
89         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
90             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
91                 distance += 1
92             elif tag == 'replace':
93                 distance += max((ato-afrom), (bto-bfrom))
94             elif tag != 'equal':
95                 distance += abs((ato-afrom) - (bto-bfrom))
96         self.penalty += (distance/len(self.lookup_word))
97
98     @staticmethod
99     def from_db_row(row: SaRow) -> 'ICUToken':
100         """ Create a ICUToken from the row of the word table.
101         """
102         count = 1 if row.info is None else row.info.get('count', 1)
103         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
104
105         penalty = 0.0
106         if row.type == 'w':
107             penalty = 0.3
108         elif row.type == 'W':
109             if len(row.word_token) == 1 and row.word_token == row.word:
110                 penalty = 0.2 if row.word.isdigit() else 0.3
111         elif row.type == 'H':
112             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
113             if all(not c.isdigit() for c in row.word_token):
114                 penalty += 0.2 * (len(row.word_token) - 1)
115         elif row.type == 'C':
116             if len(row.word_token) == 1:
117                 penalty = 0.3
118
119         if row.info is None:
120             lookup_word = row.word
121         else:
122             lookup_word = row.info.get('lookup', row.word)
123         if lookup_word:
124             lookup_word = lookup_word.split('@', 1)[0]
125         else:
126             lookup_word = row.word_token
127
128         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
129                         lookup_word=lookup_word,
130                         word_token=row.word_token, info=row.info,
131                         addr_count=max(1, addr_count))
132
133
134 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
135     """ Converter for query strings into a tokenized query
136         using the tokens created by a ICU tokenizer.
137     """
138     def __init__(self, conn: SearchConnection) -> None:
139         self.conn = conn
140
141     async def setup(self) -> None:
142         """ Set up static data structures needed for the analysis.
143         """
144         async def _make_normalizer() -> Any:
145             rules = await self.conn.get_property('tokenizer_import_normalisation')
146             return Transliterator.createFromRules("normalization", rules)
147
148         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
149                                                            _make_normalizer)
150
151         async def _make_transliterator() -> Any:
152             rules = await self.conn.get_property('tokenizer_import_transliteration')
153             return Transliterator.createFromRules("transliteration", rules)
154
155         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
156                                                                _make_transliterator)
157
158         await self._setup_preprocessing()
159
160         if 'word' not in self.conn.t.meta.tables:
161             sa.Table('word', self.conn.t.meta,
162                      sa.Column('word_id', sa.Integer),
163                      sa.Column('word_token', sa.Text, nullable=False),
164                      sa.Column('type', sa.Text, nullable=False),
165                      sa.Column('word', sa.Text),
166                      sa.Column('info', Json))
167
168     async def _setup_preprocessing(self) -> None:
169         """ Load the rules for preprocessing and set up the handlers.
170         """
171
172         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
173                                                         config='TOKENIZER_CONFIG')
174         preprocessing_rules = rules.get('query-preprocessing', [])
175
176         self.preprocessors = []
177
178         for func in preprocessing_rules:
179             if 'step' not in func:
180                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
181             if not isinstance(func['step'], str):
182                 raise UsageError("'step' attribute must be a simple string.")
183
184             module = self.conn.config.load_plugin_module(
185                         func['step'], 'nominatim_api.query_preprocessing')
186             self.preprocessors.append(
187                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
188
189     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
190         """ Analyze the given list of phrases and return the
191             tokenized query.
192         """
193         log().section('Analyze query (using ICU tokenizer)')
194         for func in self.preprocessors:
195             phrases = func(phrases)
196         query = qmod.QueryStruct(phrases)
197
198         log().var_dump('Normalized query', query.source)
199         if not query.source:
200             return query
201
202         parts, words = self.split_query(query)
203         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
204
205         for row in await self.lookup_in_db(list(words.keys())):
206             for trange in words[row.word_token]:
207                 token = ICUToken.from_db_row(row)
208                 if row.type == 'S':
209                     if row.info['op'] in ('in', 'near'):
210                         if trange.start == 0:
211                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
212                     else:
213                         if trange.start == 0 and trange.end == query.num_token_slots():
214                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
215                         else:
216                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
217                 else:
218                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
219
220         self.add_extra_tokens(query, parts)
221         self.rerank_tokens(query, parts)
222
223         log().table_dump('Word tokens', _dump_word_tokens(query))
224
225         return query
226
227     def normalize_text(self, text: str) -> str:
228         """ Bring the given text into a normalized form. That is the
229             standardized form search will work with. All information removed
230             at this stage is inevitably lost.
231         """
232         return cast(str, self.normalizer.transliterate(text))
233
234     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
235         """ Transliterate the phrases and split them into tokens.
236
237             Returns the list of transliterated tokens together with their
238             normalized form and a dictionary of words for lookup together
239             with their position.
240         """
241         parts: QueryParts = []
242         phrase_start = 0
243         words = defaultdict(list)
244         wordnr = 0
245         for phrase in query.source:
246             query.nodes[-1].ptype = phrase.ptype
247             phrase_split = re.split('([ :-])', phrase.text)
248             # The zip construct will give us the pairs of word/break from
249             # the regular expression split. As the split array ends on the
250             # final word, we simply use the fillvalue to even out the list and
251             # add the phrase break at the end.
252             for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
253                 if not word:
254                     continue
255                 trans = self.transliterator.transliterate(word)
256                 if trans:
257                     for term in trans.split(' '):
258                         if term:
259                             parts.append(QueryPart(term, word, wordnr))
260                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
261                     query.nodes[-1].btype = qmod.BreakType(breakchar)
262                 wordnr += 1
263
264             for word, wrange in yield_words(parts, phrase_start):
265                 words[word].append(wrange)
266
267             phrase_start = len(parts)
268         query.nodes[-1].btype = qmod.BreakType.END
269
270         return parts, words
271
272     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
273         """ Return the token information from the database for the
274             given word tokens.
275         """
276         t = self.conn.t.meta.tables['word']
277         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
278
279     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
280         """ Add tokens to query that are not saved in the database.
281         """
282         for part, node, i in zip(parts, query.nodes, range(1000)):
283             if len(part.token) <= 4 and part[0].isdigit()\
284                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
285                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
286                                 ICUToken(penalty=0.5, token=0,
287                                          count=1, addr_count=1, lookup_word=part.token,
288                                          word_token=part.token, info=None))
289
290     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
291         """ Add penalties to tokens that depend on presence of other token.
292         """
293         for i, node, tlist in query.iter_token_lists():
294             if tlist.ttype == qmod.TokenType.POSTCODE:
295                 for repl in node.starting:
296                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
297                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
298                             or len(tlist.tokens[0].lookup_word) > 4):
299                         repl.add_penalty(0.39)
300             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
301                   and len(tlist.tokens[0].lookup_word) <= 3):
302                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
303                     for repl in node.starting:
304                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
305                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
306             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
307                 norm = parts[i].normalized
308                 for j in range(i + 1, tlist.end):
309                     if parts[j - 1].word_number != parts[j].word_number:
310                         norm += '  ' + parts[j].normalized
311                 for token in tlist.tokens:
312                     cast(ICUToken, token).rematch(norm)
313
314
315 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
316     out = query.nodes[0].btype.value
317     for node, part in zip(query.nodes[1:], parts):
318         out += part.token + node.btype.value
319     return out
320
321
322 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
323     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
324     for node in query.nodes:
325         for tlist in node.starting:
326             for token in tlist.tokens:
327                 t = cast(ICUToken, token)
328                 yield [tlist.ttype.name, t.token, t.word_token or '',
329                        t.lookup_word or '', t.penalty, t.count, t.info]
330
331
332 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
333     """ Create and set up a new query analyzer for a database based
334         on the ICU tokenizer.
335     """
336     out = ICUQueryAnalyzer(conn)
337     await out.setup()
338
339     return out