]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
release 5.0.0.post5
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14 import re
15 from itertools import zip_longest
16
17 from icu import Transliterator
18
19 import sqlalchemy as sa
20
21 from ..errors import UsageError
22 from ..typing import SaRow
23 from ..sql.sqlalchemy_types import Json
24 from ..connection import SearchConnection
25 from ..logging import log
26 from . import query as qmod
27 from ..query_preprocessing.config import QueryConfig
28 from .query_analyzer_factory import AbstractQueryAnalyzer
29
30
31 DB_TO_TOKEN_TYPE = {
32     'W': qmod.TOKEN_WORD,
33     'w': qmod.TOKEN_PARTIAL,
34     'H': qmod.TOKEN_HOUSENUMBER,
35     'P': qmod.TOKEN_POSTCODE,
36     'C': qmod.TOKEN_COUNTRY
37 }
38
39 PENALTY_IN_TOKEN_BREAK = {
40      qmod.BREAK_START: 0.5,
41      qmod.BREAK_END: 0.5,
42      qmod.BREAK_PHRASE: 0.5,
43      qmod.BREAK_SOFT_PHRASE: 0.5,
44      qmod.BREAK_WORD: 0.1,
45      qmod.BREAK_PART: 0.0,
46      qmod.BREAK_TOKEN: 0.0
47 }
48
49
50 @dataclasses.dataclass
51 class QueryPart:
52     """ Normalized and transliterated form of a single term in the query.
53
54         When the term came out of a split during the transliteration,
55         the normalized string is the full word before transliteration.
56         Check the subsequent break type to figure out if the word is
57         continued.
58
59         Penalty is the break penalty for the break following the token.
60     """
61     token: str
62     normalized: str
63     penalty: float
64
65
66 QueryParts = List[QueryPart]
67 WordDict = Dict[str, List[qmod.TokenRange]]
68
69
70 def extract_words(terms: List[QueryPart], start: int,  words: WordDict) -> None:
71     """ Add all combinations of words in the terms list after the
72         given position to the word list.
73     """
74     total = len(terms)
75     base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
76     for first in range(start, total):
77         word = terms[first].token
78         penalty = base_penalty
79         words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
80         for last in range(first + 1, min(first + 20, total)):
81             word = ' '.join((word, terms[last].token))
82             penalty += terms[last - 1].penalty
83             words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
84
85
86 @dataclasses.dataclass
87 class ICUToken(qmod.Token):
88     """ Specialised token for ICU tokenizer.
89     """
90     word_token: str
91     info: Optional[Dict[str, Any]]
92
93     def get_category(self) -> Tuple[str, str]:
94         assert self.info
95         return self.info.get('class', ''), self.info.get('type', '')
96
97     def rematch(self, norm: str) -> None:
98         """ Check how well the token matches the given normalized string
99             and add a penalty, if necessary.
100         """
101         if not self.lookup_word:
102             return
103
104         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
105         distance = 0
106         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
107             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
108                 distance += 1
109             elif tag == 'replace':
110                 distance += max((ato-afrom), (bto-bfrom))
111             elif tag != 'equal':
112                 distance += abs((ato-afrom) - (bto-bfrom))
113         self.penalty += (distance/len(self.lookup_word))
114
115     @staticmethod
116     def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
117         """ Create a ICUToken from the row of the word table.
118         """
119         count = 1 if row.info is None else row.info.get('count', 1)
120         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
121
122         penalty = base_penalty
123         if row.type == 'w':
124             penalty += 0.3
125         elif row.type == 'W':
126             if len(row.word_token) == 1 and row.word_token == row.word:
127                 penalty += 0.2 if row.word.isdigit() else 0.3
128         elif row.type == 'H':
129             penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
130             if all(not c.isdigit() for c in row.word_token):
131                 penalty += 0.2 * (len(row.word_token) - 1)
132         elif row.type == 'C':
133             if len(row.word_token) == 1:
134                 penalty += 0.3
135
136         if row.info is None:
137             lookup_word = row.word
138         else:
139             lookup_word = row.info.get('lookup', row.word)
140         if lookup_word:
141             lookup_word = lookup_word.split('@', 1)[0]
142         else:
143             lookup_word = row.word_token
144
145         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
146                         lookup_word=lookup_word,
147                         word_token=row.word_token, info=row.info,
148                         addr_count=max(1, addr_count))
149
150
151 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
152     """ Converter for query strings into a tokenized query
153         using the tokens created by a ICU tokenizer.
154     """
155     def __init__(self, conn: SearchConnection) -> None:
156         self.conn = conn
157
158     async def setup(self) -> None:
159         """ Set up static data structures needed for the analysis.
160         """
161         async def _make_normalizer() -> Any:
162             rules = await self.conn.get_property('tokenizer_import_normalisation')
163             return Transliterator.createFromRules("normalization", rules)
164
165         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
166                                                            _make_normalizer)
167
168         async def _make_transliterator() -> Any:
169             rules = await self.conn.get_property('tokenizer_import_transliteration')
170             return Transliterator.createFromRules("transliteration", rules)
171
172         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
173                                                                _make_transliterator)
174
175         await self._setup_preprocessing()
176
177         if 'word' not in self.conn.t.meta.tables:
178             sa.Table('word', self.conn.t.meta,
179                      sa.Column('word_id', sa.Integer),
180                      sa.Column('word_token', sa.Text, nullable=False),
181                      sa.Column('type', sa.Text, nullable=False),
182                      sa.Column('word', sa.Text),
183                      sa.Column('info', Json))
184
185     async def _setup_preprocessing(self) -> None:
186         """ Load the rules for preprocessing and set up the handlers.
187         """
188
189         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
190                                                         config='TOKENIZER_CONFIG')
191         preprocessing_rules = rules.get('query-preprocessing', [])
192
193         self.preprocessors = []
194
195         for func in preprocessing_rules:
196             if 'step' not in func:
197                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
198             if not isinstance(func['step'], str):
199                 raise UsageError("'step' attribute must be a simple string.")
200
201             module = self.conn.config.load_plugin_module(
202                         func['step'], 'nominatim_api.query_preprocessing')
203             self.preprocessors.append(
204                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
205
206     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
207         """ Analyze the given list of phrases and return the
208             tokenized query.
209         """
210         log().section('Analyze query (using ICU tokenizer)')
211         for func in self.preprocessors:
212             phrases = func(phrases)
213
214         if len(phrases) == 1 \
215                 and phrases[0].text.count(' ') > 3 \
216                 and max(len(s) for s in phrases[0].text.split()) < 3:
217             normalized = []
218
219         query = qmod.QueryStruct(phrases)
220
221         log().var_dump('Normalized query', query.source)
222         if not query.source:
223             return query
224
225         parts, words = self.split_query(query)
226         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
227
228         for row in await self.lookup_in_db(list(words.keys())):
229             for trange in words[row.word_token]:
230                 token = ICUToken.from_db_row(row, trange.penalty or 0.0)
231                 if row.type == 'S':
232                     if row.info['op'] in ('in', 'near'):
233                         if trange.start == 0:
234                             query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
235                     else:
236                         if trange.start == 0 and trange.end == query.num_token_slots():
237                             query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
238                         else:
239                             query.add_token(trange, qmod.TOKEN_QUALIFIER, token)
240                 else:
241                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
242
243         self.add_extra_tokens(query, parts)
244         self.rerank_tokens(query, parts)
245
246         log().table_dump('Word tokens', _dump_word_tokens(query))
247
248         return query
249
250     def normalize_text(self, text: str) -> str:
251         """ Bring the given text into a normalized form. That is the
252             standardized form search will work with. All information removed
253             at this stage is inevitably lost.
254         """
255         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
256
257     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
258         """ Transliterate the phrases and split them into tokens.
259
260             Returns the list of transliterated tokens together with their
261             normalized form and a dictionary of words for lookup together
262             with their position.
263         """
264         parts: QueryParts = []
265         phrase_start = 0
266         words: WordDict = defaultdict(list)
267         for phrase in query.source:
268             query.nodes[-1].ptype = phrase.ptype
269             phrase_split = re.split('([ :-])', phrase.text)
270             # The zip construct will give us the pairs of word/break from
271             # the regular expression split. As the split array ends on the
272             # final word, we simply use the fillvalue to even out the list and
273             # add the phrase break at the end.
274             for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
275                 if not word:
276                     continue
277                 trans = self.transliterator.transliterate(word)
278                 if trans:
279                     for term in trans.split(' '):
280                         if term:
281                             parts.append(QueryPart(term, word,
282                                                    PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
283                             query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
284                     query.nodes[-1].btype = breakchar
285                     parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
286
287             extract_words(parts, phrase_start, words)
288
289             phrase_start = len(parts)
290         query.nodes[-1].btype = qmod.BREAK_END
291
292         return parts, words
293
294     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
295         """ Return the token information from the database for the
296             given word tokens.
297         """
298         t = self.conn.t.meta.tables['word']
299         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
300
301     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
302         """ Add tokens to query that are not saved in the database.
303         """
304         for part, node, i in zip(parts, query.nodes, range(1000)):
305             if len(part.token) <= 4 and part.token.isdigit()\
306                and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
307                 query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
308                                 ICUToken(penalty=0.5, token=0,
309                                          count=1, addr_count=1, lookup_word=part.token,
310                                          word_token=part.token, info=None))
311
312     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
313         """ Add penalties to tokens that depend on presence of other token.
314         """
315         for i, node, tlist in query.iter_token_lists():
316             if tlist.ttype == qmod.TOKEN_POSTCODE:
317                 for repl in node.starting:
318                     if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
319                        and (repl.ttype != qmod.TOKEN_HOUSENUMBER
320                             or len(tlist.tokens[0].lookup_word) > 4):
321                         repl.add_penalty(0.39)
322             elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
323                   and len(tlist.tokens[0].lookup_word) <= 3):
324                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
325                     for repl in node.starting:
326                         if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
327                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
328             elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
329                 norm = parts[i].normalized
330                 for j in range(i + 1, tlist.end):
331                     if node.btype != qmod.BREAK_TOKEN:
332                         norm += '  ' + parts[j].normalized
333                 for token in tlist.tokens:
334                     cast(ICUToken, token).rematch(norm)
335
336
337 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
338     out = query.nodes[0].btype
339     for node, part in zip(query.nodes[1:], parts):
340         out += part.token + node.btype
341     return out
342
343
344 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
345     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
346     for node in query.nodes:
347         for tlist in node.starting:
348             for token in tlist.tokens:
349                 t = cast(ICUToken, token)
350                 yield [tlist.ttype, t.token, t.word_token or '',
351                        t.lookup_word or '', t.penalty, t.count, t.info]
352
353
354 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
355     """ Create and set up a new query analyzer for a database based
356         on the ICU tokenizer.
357     """
358     out = ICUQueryAnalyzer(conn)
359     await out.setup()
360
361     return out