]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
search: merge QueryPart array with QueryNodes
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14 import re
15 from itertools import zip_longest
16
17 from icu import Transliterator
18
19 import sqlalchemy as sa
20
21 from ..errors import UsageError
22 from ..typing import SaRow
23 from ..sql.sqlalchemy_types import Json
24 from ..connection import SearchConnection
25 from ..logging import log
26 from . import query as qmod
27 from ..query_preprocessing.config import QueryConfig
28 from .query_analyzer_factory import AbstractQueryAnalyzer
29
30
31 DB_TO_TOKEN_TYPE = {
32     'W': qmod.TOKEN_WORD,
33     'w': qmod.TOKEN_PARTIAL,
34     'H': qmod.TOKEN_HOUSENUMBER,
35     'P': qmod.TOKEN_POSTCODE,
36     'C': qmod.TOKEN_COUNTRY
37 }
38
39 PENALTY_IN_TOKEN_BREAK = {
40      qmod.BREAK_START: 0.5,
41      qmod.BREAK_END: 0.5,
42      qmod.BREAK_PHRASE: 0.5,
43      qmod.BREAK_SOFT_PHRASE: 0.5,
44      qmod.BREAK_WORD: 0.1,
45      qmod.BREAK_PART: 0.0,
46      qmod.BREAK_TOKEN: 0.0
47 }
48
49
50 WordDict = Dict[str, List[qmod.TokenRange]]
51
52
53 def extract_words(query: qmod.QueryStruct, start: int,  words: WordDict) -> None:
54     """ Add all combinations of words in the terms list starting with
55         the term leading into node 'start'.
56
57         The words found will be added into the 'words' dictionary with
58         their start and end position.
59     """
60     nodes = query.nodes
61     total = len(nodes)
62     base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
63     for first in range(start, total):
64         word = nodes[first].term_lookup
65         penalty = base_penalty
66         words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty))
67         for last in range(first + 1, min(first + 20, total)):
68             word = ' '.join((word, nodes[last].term_lookup))
69             penalty += nodes[last - 1].penalty
70             words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty))
71
72
73 @dataclasses.dataclass
74 class ICUToken(qmod.Token):
75     """ Specialised token for ICU tokenizer.
76     """
77     word_token: str
78     info: Optional[Dict[str, Any]]
79
80     def get_category(self) -> Tuple[str, str]:
81         assert self.info
82         return self.info.get('class', ''), self.info.get('type', '')
83
84     def rematch(self, norm: str) -> None:
85         """ Check how well the token matches the given normalized string
86             and add a penalty, if necessary.
87         """
88         if not self.lookup_word:
89             return
90
91         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
92         distance = 0
93         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
94             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
95                 distance += 1
96             elif tag == 'replace':
97                 distance += max((ato-afrom), (bto-bfrom))
98             elif tag != 'equal':
99                 distance += abs((ato-afrom) - (bto-bfrom))
100         self.penalty += (distance/len(self.lookup_word))
101
102     @staticmethod
103     def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
104         """ Create a ICUToken from the row of the word table.
105         """
106         count = 1 if row.info is None else row.info.get('count', 1)
107         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
108
109         penalty = base_penalty
110         if row.type == 'w':
111             penalty += 0.3
112         elif row.type == 'W':
113             if len(row.word_token) == 1 and row.word_token == row.word:
114                 penalty += 0.2 if row.word.isdigit() else 0.3
115         elif row.type == 'H':
116             penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
117             if all(not c.isdigit() for c in row.word_token):
118                 penalty += 0.2 * (len(row.word_token) - 1)
119         elif row.type == 'C':
120             if len(row.word_token) == 1:
121                 penalty += 0.3
122
123         if row.info is None:
124             lookup_word = row.word
125         else:
126             lookup_word = row.info.get('lookup', row.word)
127         if lookup_word:
128             lookup_word = lookup_word.split('@', 1)[0]
129         else:
130             lookup_word = row.word_token
131
132         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
133                         lookup_word=lookup_word,
134                         word_token=row.word_token, info=row.info,
135                         addr_count=max(1, addr_count))
136
137
138 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
139     """ Converter for query strings into a tokenized query
140         using the tokens created by a ICU tokenizer.
141     """
142     def __init__(self, conn: SearchConnection) -> None:
143         self.conn = conn
144
145     async def setup(self) -> None:
146         """ Set up static data structures needed for the analysis.
147         """
148         async def _make_normalizer() -> Any:
149             rules = await self.conn.get_property('tokenizer_import_normalisation')
150             return Transliterator.createFromRules("normalization", rules)
151
152         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
153                                                            _make_normalizer)
154
155         async def _make_transliterator() -> Any:
156             rules = await self.conn.get_property('tokenizer_import_transliteration')
157             return Transliterator.createFromRules("transliteration", rules)
158
159         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
160                                                                _make_transliterator)
161
162         await self._setup_preprocessing()
163
164         if 'word' not in self.conn.t.meta.tables:
165             sa.Table('word', self.conn.t.meta,
166                      sa.Column('word_id', sa.Integer),
167                      sa.Column('word_token', sa.Text, nullable=False),
168                      sa.Column('type', sa.Text, nullable=False),
169                      sa.Column('word', sa.Text),
170                      sa.Column('info', Json))
171
172     async def _setup_preprocessing(self) -> None:
173         """ Load the rules for preprocessing and set up the handlers.
174         """
175
176         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
177                                                         config='TOKENIZER_CONFIG')
178         preprocessing_rules = rules.get('query-preprocessing', [])
179
180         self.preprocessors = []
181
182         for func in preprocessing_rules:
183             if 'step' not in func:
184                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
185             if not isinstance(func['step'], str):
186                 raise UsageError("'step' attribute must be a simple string.")
187
188             module = self.conn.config.load_plugin_module(
189                         func['step'], 'nominatim_api.query_preprocessing')
190             self.preprocessors.append(
191                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
192
193     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
194         """ Analyze the given list of phrases and return the
195             tokenized query.
196         """
197         log().section('Analyze query (using ICU tokenizer)')
198         for func in self.preprocessors:
199             phrases = func(phrases)
200         query = qmod.QueryStruct(phrases)
201
202         log().var_dump('Normalized query', query.source)
203         if not query.source:
204             return query
205
206         words = self.split_query(query)
207         log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
208
209         for row in await self.lookup_in_db(list(words.keys())):
210             for trange in words[row.word_token]:
211                 token = ICUToken.from_db_row(row, trange.penalty or 0.0)
212                 if row.type == 'S':
213                     if row.info['op'] in ('in', 'near'):
214                         if trange.start == 0:
215                             query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
216                     else:
217                         if trange.start == 0 and trange.end == query.num_token_slots():
218                             query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
219                         else:
220                             query.add_token(trange, qmod.TOKEN_QUALIFIER, token)
221                 else:
222                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
223
224         self.add_extra_tokens(query)
225         self.rerank_tokens(query)
226
227         log().table_dump('Word tokens', _dump_word_tokens(query))
228
229         return query
230
231     def normalize_text(self, text: str) -> str:
232         """ Bring the given text into a normalized form. That is the
233             standardized form search will work with. All information removed
234             at this stage is inevitably lost.
235         """
236         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
237
238     def split_query(self, query: qmod.QueryStruct) -> WordDict:
239         """ Transliterate the phrases and split them into tokens.
240
241             Returns a dictionary of words for lookup together
242             with their position.
243         """
244         phrase_start = 1
245         words: WordDict = defaultdict(list)
246         for phrase in query.source:
247             query.nodes[-1].ptype = phrase.ptype
248             phrase_split = re.split('([ :-])', phrase.text)
249             # The zip construct will give us the pairs of word/break from
250             # the regular expression split. As the split array ends on the
251             # final word, we simply use the fillvalue to even out the list and
252             # add the phrase break at the end.
253             for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
254                 if not word:
255                     continue
256                 trans = self.transliterator.transliterate(word)
257                 if trans:
258                     for term in trans.split(' '):
259                         if term:
260                             query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
261                                            PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
262                                            term, word)
263                     query.nodes[-1].adjust_break(breakchar,
264                                                  PENALTY_IN_TOKEN_BREAK[breakchar])
265
266             extract_words(query, phrase_start, words)
267
268             phrase_start = len(query.nodes)
269         query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
270
271         return words
272
273     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
274         """ Return the token information from the database for the
275             given word tokens.
276         """
277         t = self.conn.t.meta.tables['word']
278         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
279
280     def add_extra_tokens(self, query: qmod.QueryStruct) -> None:
281         """ Add tokens to query that are not saved in the database.
282         """
283         need_hnr = False
284         for i, node in enumerate(query.nodes):
285             is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART)
286             if need_hnr and is_full_token \
287                     and len(node.term_normalized) <= 4 and node.term_normalized.isdigit():
288                 query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER,
289                                 ICUToken(penalty=0.5, token=0,
290                                          count=1, addr_count=1,
291                                          lookup_word=node.term_lookup,
292                                          word_token=node.term_lookup, info=None))
293
294             need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER)
295
296     def rerank_tokens(self, query: qmod.QueryStruct) -> None:
297         """ Add penalties to tokens that depend on presence of other token.
298         """
299         for i, node, tlist in query.iter_token_lists():
300             if tlist.ttype == qmod.TOKEN_POSTCODE:
301                 for repl in node.starting:
302                     if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
303                        and (repl.ttype != qmod.TOKEN_HOUSENUMBER
304                             or len(tlist.tokens[0].lookup_word) > 4):
305                         repl.add_penalty(0.39)
306             elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
307                   and len(tlist.tokens[0].lookup_word) <= 3):
308                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
309                     for repl in node.starting:
310                         if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
311                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
312             elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
313                 norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
314                                 if n.btype != qmod.BREAK_TOKEN)
315                 if not norm:
316                     # Can happen when the token only covers a partial term
317                     norm = query.nodes[i + 1].term_normalized
318                 for token in tlist.tokens:
319                     cast(ICUToken, token).rematch(norm)
320
321
322 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
323     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
324     for node in query.nodes:
325         for tlist in node.starting:
326             for token in tlist.tokens:
327                 t = cast(ICUToken, token)
328                 yield [tlist.ttype, t.token, t.word_token or '',
329                        t.lookup_word or '', t.penalty, t.count, t.info]
330
331
332 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
333     """ Create and set up a new query analyzer for a database based
334         on the ICU tokenizer.
335     """
336     out = ICUQueryAnalyzer(conn)
337     await out.setup()
338
339     return out