]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14 import re
15 from itertools import zip_longest
16
17 from icu import Transliterator
18
19 import sqlalchemy as sa
20
21 from ..errors import UsageError
22 from ..typing import SaRow
23 from ..sql.sqlalchemy_types import Json
24 from ..connection import SearchConnection
25 from ..logging import log
26 from . import query as qmod
27 from ..query_preprocessing.config import QueryConfig
28 from .query_analyzer_factory import AbstractQueryAnalyzer
29
30
31 DB_TO_TOKEN_TYPE = {
32     'W': qmod.TokenType.WORD,
33     'w': qmod.TokenType.PARTIAL,
34     'H': qmod.TokenType.HOUSENUMBER,
35     'P': qmod.TokenType.POSTCODE,
36     'C': qmod.TokenType.COUNTRY
37 }
38
39 PENALTY_IN_TOKEN_BREAK = {
40      qmod.BreakType.START: 0.5,
41      qmod.BreakType.END: 0.5,
42      qmod.BreakType.PHRASE: 0.5,
43      qmod.BreakType.SOFT_PHRASE: 0.5,
44      qmod.BreakType.WORD: 0.1,
45      qmod.BreakType.PART: 0.0,
46      qmod.BreakType.TOKEN: 0.0
47 }
48
49
50 @dataclasses.dataclass
51 class QueryPart:
52     """ Normalized and transliterated form of a single term in the query.
53         When the term came out of a split during the transliteration,
54         the normalized string is the full word before transliteration.
55         The word number keeps track of the word before transliteration
56         and can be used to identify partial transliterated terms.
57         Penalty is the break penalty for the break following the token.
58     """
59     token: str
60     normalized: str
61     word_number: int
62     penalty: float
63
64
65 QueryParts = List[QueryPart]
66 WordDict = Dict[str, List[qmod.TokenRange]]
67
68
69 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
70     """ Return all combinations of words in the terms list after the
71         given position.
72     """
73     total = len(terms)
74     for first in range(start, total):
75         word = terms[first].token
76         penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
77         yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
78         for last in range(first + 1, min(first + 20, total)):
79             word = ' '.join((word, terms[last].token))
80             penalty += terms[last - 1].penalty
81             yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
82
83
84 @dataclasses.dataclass
85 class ICUToken(qmod.Token):
86     """ Specialised token for ICU tokenizer.
87     """
88     word_token: str
89     info: Optional[Dict[str, Any]]
90
91     def get_category(self) -> Tuple[str, str]:
92         assert self.info
93         return self.info.get('class', ''), self.info.get('type', '')
94
95     def rematch(self, norm: str) -> None:
96         """ Check how well the token matches the given normalized string
97             and add a penalty, if necessary.
98         """
99         if not self.lookup_word:
100             return
101
102         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
103         distance = 0
104         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
105             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
106                 distance += 1
107             elif tag == 'replace':
108                 distance += max((ato-afrom), (bto-bfrom))
109             elif tag != 'equal':
110                 distance += abs((ato-afrom) - (bto-bfrom))
111         self.penalty += (distance/len(self.lookup_word))
112
113     @staticmethod
114     def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
115         """ Create a ICUToken from the row of the word table.
116         """
117         count = 1 if row.info is None else row.info.get('count', 1)
118         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
119
120         penalty = base_penalty
121         if row.type == 'w':
122             penalty += 0.3
123         elif row.type == 'W':
124             if len(row.word_token) == 1 and row.word_token == row.word:
125                 penalty += 0.2 if row.word.isdigit() else 0.3
126         elif row.type == 'H':
127             penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
128             if all(not c.isdigit() for c in row.word_token):
129                 penalty += 0.2 * (len(row.word_token) - 1)
130         elif row.type == 'C':
131             if len(row.word_token) == 1:
132                 penalty += 0.3
133
134         if row.info is None:
135             lookup_word = row.word
136         else:
137             lookup_word = row.info.get('lookup', row.word)
138         if lookup_word:
139             lookup_word = lookup_word.split('@', 1)[0]
140         else:
141             lookup_word = row.word_token
142
143         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
144                         lookup_word=lookup_word,
145                         word_token=row.word_token, info=row.info,
146                         addr_count=max(1, addr_count))
147
148
149 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
150     """ Converter for query strings into a tokenized query
151         using the tokens created by a ICU tokenizer.
152     """
153     def __init__(self, conn: SearchConnection) -> None:
154         self.conn = conn
155
156     async def setup(self) -> None:
157         """ Set up static data structures needed for the analysis.
158         """
159         async def _make_normalizer() -> Any:
160             rules = await self.conn.get_property('tokenizer_import_normalisation')
161             return Transliterator.createFromRules("normalization", rules)
162
163         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
164                                                            _make_normalizer)
165
166         async def _make_transliterator() -> Any:
167             rules = await self.conn.get_property('tokenizer_import_transliteration')
168             return Transliterator.createFromRules("transliteration", rules)
169
170         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
171                                                                _make_transliterator)
172
173         await self._setup_preprocessing()
174
175         if 'word' not in self.conn.t.meta.tables:
176             sa.Table('word', self.conn.t.meta,
177                      sa.Column('word_id', sa.Integer),
178                      sa.Column('word_token', sa.Text, nullable=False),
179                      sa.Column('type', sa.Text, nullable=False),
180                      sa.Column('word', sa.Text),
181                      sa.Column('info', Json))
182
183     async def _setup_preprocessing(self) -> None:
184         """ Load the rules for preprocessing and set up the handlers.
185         """
186
187         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
188                                                         config='TOKENIZER_CONFIG')
189         preprocessing_rules = rules.get('query-preprocessing', [])
190
191         self.preprocessors = []
192
193         for func in preprocessing_rules:
194             if 'step' not in func:
195                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
196             if not isinstance(func['step'], str):
197                 raise UsageError("'step' attribute must be a simple string.")
198
199             module = self.conn.config.load_plugin_module(
200                         func['step'], 'nominatim_api.query_preprocessing')
201             self.preprocessors.append(
202                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
203
204     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
205         """ Analyze the given list of phrases and return the
206             tokenized query.
207         """
208         log().section('Analyze query (using ICU tokenizer)')
209         for func in self.preprocessors:
210             phrases = func(phrases)
211
212         if len(phrases) == 1 \
213                 and phrases[0].text.count(' ') > 3 \
214                 and max(len(s) for s in phrases[0].text.split()) < 3:
215             normalized = []
216
217         query = qmod.QueryStruct(phrases)
218
219         log().var_dump('Normalized query', query.source)
220         if not query.source:
221             return query
222
223         parts, words = self.split_query(query)
224         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
225
226         for row in await self.lookup_in_db(list(words.keys())):
227             for trange in words[row.word_token]:
228                 token = ICUToken.from_db_row(row, trange.penalty or 0.0)
229                 if row.type == 'S':
230                     if row.info['op'] in ('in', 'near'):
231                         if trange.start == 0:
232                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
233                     else:
234                         if trange.start == 0 and trange.end == query.num_token_slots():
235                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
236                         else:
237                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
238                 else:
239                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
240
241         self.add_extra_tokens(query, parts)
242         self.rerank_tokens(query, parts)
243
244         log().table_dump('Word tokens', _dump_word_tokens(query))
245
246         return query
247
248     def normalize_text(self, text: str) -> str:
249         """ Bring the given text into a normalized form. That is the
250             standardized form search will work with. All information removed
251             at this stage is inevitably lost.
252         """
253         return cast(str, self.normalizer.transliterate(text))
254
255     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
256         """ Transliterate the phrases and split them into tokens.
257
258             Returns the list of transliterated tokens together with their
259             normalized form and a dictionary of words for lookup together
260             with their position.
261         """
262         parts: QueryParts = []
263         phrase_start = 0
264         words = defaultdict(list)
265         wordnr = 0
266         for phrase in query.source:
267             query.nodes[-1].ptype = phrase.ptype
268             phrase_split = re.split('([ :-])', phrase.text)
269             # The zip construct will give us the pairs of word/break from
270             # the regular expression split. As the split array ends on the
271             # final word, we simply use the fillvalue to even out the list and
272             # add the phrase break at the end.
273             for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
274                 if not word:
275                     continue
276                 trans = self.transliterator.transliterate(word)
277                 if trans:
278                     for term in trans.split(' '):
279                         if term:
280                             parts.append(QueryPart(term, word, wordnr,
281                                                    PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
282                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
283                     query.nodes[-1].btype = qmod.BreakType(breakchar)
284                     parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
285                 wordnr += 1
286
287             for word, wrange in yield_words(parts, phrase_start):
288                 words[word].append(wrange)
289
290             phrase_start = len(parts)
291         query.nodes[-1].btype = qmod.BreakType.END
292
293         return parts, words
294
295     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
296         """ Return the token information from the database for the
297             given word tokens.
298         """
299         t = self.conn.t.meta.tables['word']
300         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
301
302     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
303         """ Add tokens to query that are not saved in the database.
304         """
305         for part, node, i in zip(parts, query.nodes, range(1000)):
306             if len(part.token) <= 4 and part.token.isdigit()\
307                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
308                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
309                                 ICUToken(penalty=0.5, token=0,
310                                          count=1, addr_count=1, lookup_word=part.token,
311                                          word_token=part.token, info=None))
312
313     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
314         """ Add penalties to tokens that depend on presence of other token.
315         """
316         for i, node, tlist in query.iter_token_lists():
317             if tlist.ttype == qmod.TokenType.POSTCODE:
318                 for repl in node.starting:
319                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
320                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
321                             or len(tlist.tokens[0].lookup_word) > 4):
322                         repl.add_penalty(0.39)
323             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
324                   and len(tlist.tokens[0].lookup_word) <= 3):
325                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
326                     for repl in node.starting:
327                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
328                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
329             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
330                 norm = parts[i].normalized
331                 for j in range(i + 1, tlist.end):
332                     if parts[j - 1].word_number != parts[j].word_number:
333                         norm += '  ' + parts[j].normalized
334                 for token in tlist.tokens:
335                     cast(ICUToken, token).rematch(norm)
336
337
338 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
339     out = query.nodes[0].btype.value
340     for node, part in zip(query.nodes[1:], parts):
341         out += part.token + node.btype.value
342     return out
343
344
345 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
346     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
347     for node in query.nodes:
348         for tlist in node.starting:
349             for token in tlist.tokens:
350                 t = cast(ICUToken, token)
351                 yield [tlist.ttype.name, t.token, t.word_token or '',
352                        t.lookup_word or '', t.penalty, t.count, t.info]
353
354
355 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
356     """ Create and set up a new query analyzer for a database based
357         on the ICU tokenizer.
358     """
359     out = ICUQueryAnalyzer(conn)
360     await out.setup()
361
362     return out