]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
5976fbec05d8c515dfff092606942b62e602aaac
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14
15 from icu import Transliterator
16
17 import sqlalchemy as sa
18
19 from ..errors import UsageError
20 from ..typing import SaRow
21 from ..sql.sqlalchemy_types import Json
22 from ..connection import SearchConnection
23 from ..logging import log
24 from . import query as qmod
25 from ..query_preprocessing.config import QueryConfig
26 from .query_analyzer_factory import AbstractQueryAnalyzer
27
28
29 DB_TO_TOKEN_TYPE = {
30     'W': qmod.TokenType.WORD,
31     'w': qmod.TokenType.PARTIAL,
32     'H': qmod.TokenType.HOUSENUMBER,
33     'P': qmod.TokenType.POSTCODE,
34     'C': qmod.TokenType.COUNTRY
35 }
36
37
38 class QueryPart(NamedTuple):
39     """ Normalized and transliterated form of a single term in the query.
40         When the term came out of a split during the transliteration,
41         the normalized string is the full word before transliteration.
42         The word number keeps track of the word before transliteration
43         and can be used to identify partial transliterated terms.
44     """
45     token: str
46     normalized: str
47     word_number: int
48
49
50 QueryParts = List[QueryPart]
51 WordDict = Dict[str, List[qmod.TokenRange]]
52
53
54 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
55     """ Return all combinations of words in the terms list after the
56         given position.
57     """
58     total = len(terms)
59     for first in range(start, total):
60         word = terms[first].token
61         yield word, qmod.TokenRange(first, first + 1)
62         for last in range(first + 1, min(first + 20, total)):
63             word = ' '.join((word, terms[last].token))
64             yield word, qmod.TokenRange(first, last + 1)
65
66
67 @dataclasses.dataclass
68 class ICUToken(qmod.Token):
69     """ Specialised token for ICU tokenizer.
70     """
71     word_token: str
72     info: Optional[Dict[str, Any]]
73
74     def get_category(self) -> Tuple[str, str]:
75         assert self.info
76         return self.info.get('class', ''), self.info.get('type', '')
77
78     def rematch(self, norm: str) -> None:
79         """ Check how well the token matches the given normalized string
80             and add a penalty, if necessary.
81         """
82         if not self.lookup_word:
83             return
84
85         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
86         distance = 0
87         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
88             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
89                 distance += 1
90             elif tag == 'replace':
91                 distance += max((ato-afrom), (bto-bfrom))
92             elif tag != 'equal':
93                 distance += abs((ato-afrom) - (bto-bfrom))
94         self.penalty += (distance/len(self.lookup_word))
95
96     @staticmethod
97     def from_db_row(row: SaRow) -> 'ICUToken':
98         """ Create a ICUToken from the row of the word table.
99         """
100         count = 1 if row.info is None else row.info.get('count', 1)
101         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
102
103         penalty = 0.0
104         if row.type == 'w':
105             penalty = 0.3
106         elif row.type == 'W':
107             if len(row.word_token) == 1 and row.word_token == row.word:
108                 penalty = 0.2 if row.word.isdigit() else 0.3
109         elif row.type == 'H':
110             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
111             if all(not c.isdigit() for c in row.word_token):
112                 penalty += 0.2 * (len(row.word_token) - 1)
113         elif row.type == 'C':
114             if len(row.word_token) == 1:
115                 penalty = 0.3
116
117         if row.info is None:
118             lookup_word = row.word
119         else:
120             lookup_word = row.info.get('lookup', row.word)
121         if lookup_word:
122             lookup_word = lookup_word.split('@', 1)[0]
123         else:
124             lookup_word = row.word_token
125
126         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
127                         lookup_word=lookup_word,
128                         word_token=row.word_token, info=row.info,
129                         addr_count=max(1, addr_count))
130
131
132 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
133     """ Converter for query strings into a tokenized query
134         using the tokens created by a ICU tokenizer.
135     """
136     def __init__(self, conn: SearchConnection) -> None:
137         self.conn = conn
138
139     async def setup(self) -> None:
140         """ Set up static data structures needed for the analysis.
141         """
142         async def _make_normalizer() -> Any:
143             rules = await self.conn.get_property('tokenizer_import_normalisation')
144             return Transliterator.createFromRules("normalization", rules)
145
146         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
147                                                            _make_normalizer)
148
149         async def _make_transliterator() -> Any:
150             rules = await self.conn.get_property('tokenizer_import_transliteration')
151             return Transliterator.createFromRules("transliteration", rules)
152
153         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
154                                                                _make_transliterator)
155
156         await self._setup_preprocessing()
157
158         if 'word' not in self.conn.t.meta.tables:
159             sa.Table('word', self.conn.t.meta,
160                      sa.Column('word_id', sa.Integer),
161                      sa.Column('word_token', sa.Text, nullable=False),
162                      sa.Column('type', sa.Text, nullable=False),
163                      sa.Column('word', sa.Text),
164                      sa.Column('info', Json))
165
166     async def _setup_preprocessing(self) -> None:
167         """ Load the rules for preprocessing and set up the handlers.
168         """
169
170         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
171                                                         config='TOKENIZER_CONFIG')
172         preprocessing_rules = rules.get('query-preprocessing', [])
173
174         self.preprocessors = []
175
176         for func in preprocessing_rules:
177             if 'step' not in func:
178                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
179             if not isinstance(func['step'], str):
180                 raise UsageError("'step' attribute must be a simple string.")
181
182             module = self.conn.config.load_plugin_module(
183                         func['step'], 'nominatim_api.query_preprocessing')
184             self.preprocessors.append(
185                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
186
187     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
188         """ Analyze the given list of phrases and return the
189             tokenized query.
190         """
191         log().section('Analyze query (using ICU tokenizer)')
192         for func in self.preprocessors:
193             phrases = func(phrases)
194         query = qmod.QueryStruct(phrases)
195
196         log().var_dump('Normalized query', query.source)
197         if not query.source:
198             return query
199
200         parts, words = self.split_query(query)
201         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
202
203         for row in await self.lookup_in_db(list(words.keys())):
204             for trange in words[row.word_token]:
205                 token = ICUToken.from_db_row(row)
206                 if row.type == 'S':
207                     if row.info['op'] in ('in', 'near'):
208                         if trange.start == 0:
209                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
210                     else:
211                         if trange.start == 0 and trange.end == query.num_token_slots():
212                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
213                         else:
214                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
215                 else:
216                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
217
218         self.add_extra_tokens(query, parts)
219         self.rerank_tokens(query, parts)
220
221         log().table_dump('Word tokens', _dump_word_tokens(query))
222
223         return query
224
225     def normalize_text(self, text: str) -> str:
226         """ Bring the given text into a normalized form. That is the
227             standardized form search will work with. All information removed
228             at this stage is inevitably lost.
229         """
230         return cast(str, self.normalizer.transliterate(text))
231
232     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
233         """ Transliterate the phrases and split them into tokens.
234
235             Returns the list of transliterated tokens together with their
236             normalized form and a dictionary of words for lookup together
237             with their position.
238         """
239         parts: QueryParts = []
240         phrase_start = 0
241         words = defaultdict(list)
242         wordnr = 0
243         for phrase in query.source:
244             query.nodes[-1].ptype = phrase.ptype
245             for word in phrase.text.split(' '):
246                 trans = self.transliterator.transliterate(word)
247                 if trans:
248                     for term in trans.split(' '):
249                         if term:
250                             parts.append(QueryPart(term, word, wordnr))
251                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
252                     query.nodes[-1].btype = qmod.BreakType.WORD
253                 wordnr += 1
254             query.nodes[-1].btype = qmod.BreakType.PHRASE
255
256             for word, wrange in yield_words(parts, phrase_start):
257                 words[word].append(wrange)
258
259             phrase_start = len(parts)
260         query.nodes[-1].btype = qmod.BreakType.END
261
262         return parts, words
263
264     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
265         """ Return the token information from the database for the
266             given word tokens.
267         """
268         t = self.conn.t.meta.tables['word']
269         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
270
271     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
272         """ Add tokens to query that are not saved in the database.
273         """
274         for part, node, i in zip(parts, query.nodes, range(1000)):
275             if len(part.token) <= 4 and part[0].isdigit()\
276                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
277                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
278                                 ICUToken(penalty=0.5, token=0,
279                                          count=1, addr_count=1, lookup_word=part.token,
280                                          word_token=part.token, info=None))
281
282     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
283         """ Add penalties to tokens that depend on presence of other token.
284         """
285         for i, node, tlist in query.iter_token_lists():
286             if tlist.ttype == qmod.TokenType.POSTCODE:
287                 for repl in node.starting:
288                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
289                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
290                             or len(tlist.tokens[0].lookup_word) > 4):
291                         repl.add_penalty(0.39)
292             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
293                   and len(tlist.tokens[0].lookup_word) <= 3):
294                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
295                     for repl in node.starting:
296                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
297                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
298             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
299                 norm = parts[i].normalized
300                 for j in range(i + 1, tlist.end):
301                     if parts[j - 1].word_number != parts[j].word_number:
302                         norm += '  ' + parts[j].normalized
303                 for token in tlist.tokens:
304                     cast(ICUToken, token).rematch(norm)
305
306
307 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
308     out = query.nodes[0].btype.value
309     for node, part in zip(query.nodes[1:], parts):
310         out += part.token + node.btype.value
311     return out
312
313
314 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
315     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
316     for node in query.nodes:
317         for tlist in node.starting:
318             for token in tlist.tokens:
319                 t = cast(ICUToken, token)
320                 yield [tlist.ttype.name, t.token, t.word_token or '',
321                        t.lookup_word or '', t.penalty, t.count, t.info]
322
323
324 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
325     """ Create and set up a new query analyzer for a database based
326         on the ICU tokenizer.
327     """
328     out = ICUQueryAnalyzer(conn)
329     await out.setup()
330
331     return out