]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
consistently use query module as qmod
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14 import re
15 from itertools import zip_longest
16
17 from icu import Transliterator
18
19 import sqlalchemy as sa
20
21 from ..errors import UsageError
22 from ..typing import SaRow
23 from ..sql.sqlalchemy_types import Json
24 from ..connection import SearchConnection
25 from ..logging import log
26 from . import query as qmod
27 from ..query_preprocessing.config import QueryConfig
28 from .query_analyzer_factory import AbstractQueryAnalyzer
29
30
31 DB_TO_TOKEN_TYPE = {
32     'W': qmod.TokenType.WORD,
33     'w': qmod.TokenType.PARTIAL,
34     'H': qmod.TokenType.HOUSENUMBER,
35     'P': qmod.TokenType.POSTCODE,
36     'C': qmod.TokenType.COUNTRY
37 }
38
39 PENALTY_IN_TOKEN_BREAK = {
40      qmod.BreakType.START: 0.5,
41      qmod.BreakType.END: 0.5,
42      qmod.BreakType.PHRASE: 0.5,
43      qmod.BreakType.SOFT_PHRASE: 0.5,
44      qmod.BreakType.WORD: 0.1,
45      qmod.BreakType.PART: 0.0,
46      qmod.BreakType.TOKEN: 0.0
47 }
48
49
50 @dataclasses.dataclass
51 class QueryPart:
52     """ Normalized and transliterated form of a single term in the query.
53
54         When the term came out of a split during the transliteration,
55         the normalized string is the full word before transliteration.
56         Check the subsequent break type to figure out if the word is
57         continued.
58
59         Penalty is the break penalty for the break following the token.
60     """
61     token: str
62     normalized: str
63     penalty: float
64
65
66 QueryParts = List[QueryPart]
67 WordDict = Dict[str, List[qmod.TokenRange]]
68
69
70 def extract_words(terms: List[QueryPart], start: int,  words: WordDict) -> None:
71     """ Add all combinations of words in the terms list after the
72         given position to the word list.
73     """
74     total = len(terms)
75     base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
76     for first in range(start, total):
77         word = terms[first].token
78         penalty = base_penalty
79         words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
80         for last in range(first + 1, min(first + 20, total)):
81             word = ' '.join((word, terms[last].token))
82             penalty += terms[last - 1].penalty
83             words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
84
85
86 @dataclasses.dataclass
87 class ICUToken(qmod.Token):
88     """ Specialised token for ICU tokenizer.
89     """
90     word_token: str
91     info: Optional[Dict[str, Any]]
92
93     def get_category(self) -> Tuple[str, str]:
94         assert self.info
95         return self.info.get('class', ''), self.info.get('type', '')
96
97     def rematch(self, norm: str) -> None:
98         """ Check how well the token matches the given normalized string
99             and add a penalty, if necessary.
100         """
101         if not self.lookup_word:
102             return
103
104         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
105         distance = 0
106         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
107             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
108                 distance += 1
109             elif tag == 'replace':
110                 distance += max((ato-afrom), (bto-bfrom))
111             elif tag != 'equal':
112                 distance += abs((ato-afrom) - (bto-bfrom))
113         self.penalty += (distance/len(self.lookup_word))
114
115     @staticmethod
116     def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
117         """ Create a ICUToken from the row of the word table.
118         """
119         count = 1 if row.info is None else row.info.get('count', 1)
120         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
121
122         penalty = base_penalty
123         if row.type == 'w':
124             penalty += 0.3
125         elif row.type == 'W':
126             if len(row.word_token) == 1 and row.word_token == row.word:
127                 penalty += 0.2 if row.word.isdigit() else 0.3
128         elif row.type == 'H':
129             penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
130             if all(not c.isdigit() for c in row.word_token):
131                 penalty += 0.2 * (len(row.word_token) - 1)
132         elif row.type == 'C':
133             if len(row.word_token) == 1:
134                 penalty += 0.3
135
136         if row.info is None:
137             lookup_word = row.word
138         else:
139             lookup_word = row.info.get('lookup', row.word)
140         if lookup_word:
141             lookup_word = lookup_word.split('@', 1)[0]
142         else:
143             lookup_word = row.word_token
144
145         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
146                         lookup_word=lookup_word,
147                         word_token=row.word_token, info=row.info,
148                         addr_count=max(1, addr_count))
149
150
151 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
152     """ Converter for query strings into a tokenized query
153         using the tokens created by a ICU tokenizer.
154     """
155     def __init__(self, conn: SearchConnection) -> None:
156         self.conn = conn
157
158     async def setup(self) -> None:
159         """ Set up static data structures needed for the analysis.
160         """
161         async def _make_normalizer() -> Any:
162             rules = await self.conn.get_property('tokenizer_import_normalisation')
163             return Transliterator.createFromRules("normalization", rules)
164
165         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
166                                                            _make_normalizer)
167
168         async def _make_transliterator() -> Any:
169             rules = await self.conn.get_property('tokenizer_import_transliteration')
170             return Transliterator.createFromRules("transliteration", rules)
171
172         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
173                                                                _make_transliterator)
174
175         await self._setup_preprocessing()
176
177         if 'word' not in self.conn.t.meta.tables:
178             sa.Table('word', self.conn.t.meta,
179                      sa.Column('word_id', sa.Integer),
180                      sa.Column('word_token', sa.Text, nullable=False),
181                      sa.Column('type', sa.Text, nullable=False),
182                      sa.Column('word', sa.Text),
183                      sa.Column('info', Json))
184
185     async def _setup_preprocessing(self) -> None:
186         """ Load the rules for preprocessing and set up the handlers.
187         """
188
189         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
190                                                         config='TOKENIZER_CONFIG')
191         preprocessing_rules = rules.get('query-preprocessing', [])
192
193         self.preprocessors = []
194
195         for func in preprocessing_rules:
196             if 'step' not in func:
197                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
198             if not isinstance(func['step'], str):
199                 raise UsageError("'step' attribute must be a simple string.")
200
201             module = self.conn.config.load_plugin_module(
202                         func['step'], 'nominatim_api.query_preprocessing')
203             self.preprocessors.append(
204                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
205
206     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
207         """ Analyze the given list of phrases and return the
208             tokenized query.
209         """
210         log().section('Analyze query (using ICU tokenizer)')
211         for func in self.preprocessors:
212             phrases = func(phrases)
213         query = qmod.QueryStruct(phrases)
214
215         log().var_dump('Normalized query', query.source)
216         if not query.source:
217             return query
218
219         parts, words = self.split_query(query)
220         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
221
222         for row in await self.lookup_in_db(list(words.keys())):
223             for trange in words[row.word_token]:
224                 token = ICUToken.from_db_row(row, trange.penalty or 0.0)
225                 if row.type == 'S':
226                     if row.info['op'] in ('in', 'near'):
227                         if trange.start == 0:
228                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
229                     else:
230                         if trange.start == 0 and trange.end == query.num_token_slots():
231                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
232                         else:
233                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
234                 else:
235                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
236
237         self.add_extra_tokens(query, parts)
238         self.rerank_tokens(query, parts)
239
240         log().table_dump('Word tokens', _dump_word_tokens(query))
241
242         return query
243
244     def normalize_text(self, text: str) -> str:
245         """ Bring the given text into a normalized form. That is the
246             standardized form search will work with. All information removed
247             at this stage is inevitably lost.
248         """
249         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
250
251     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
252         """ Transliterate the phrases and split them into tokens.
253
254             Returns the list of transliterated tokens together with their
255             normalized form and a dictionary of words for lookup together
256             with their position.
257         """
258         parts: QueryParts = []
259         phrase_start = 0
260         words: WordDict = defaultdict(list)
261         for phrase in query.source:
262             query.nodes[-1].ptype = phrase.ptype
263             phrase_split = re.split('([ :-])', phrase.text)
264             # The zip construct will give us the pairs of word/break from
265             # the regular expression split. As the split array ends on the
266             # final word, we simply use the fillvalue to even out the list and
267             # add the phrase break at the end.
268             for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
269                 if not word:
270                     continue
271                 trans = self.transliterator.transliterate(word)
272                 if trans:
273                     for term in trans.split(' '):
274                         if term:
275                             parts.append(QueryPart(term, word,
276                                                    PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
277                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
278                     query.nodes[-1].btype = qmod.BreakType(breakchar)
279                     parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
280
281             extract_words(parts, phrase_start, words)
282
283             phrase_start = len(parts)
284         query.nodes[-1].btype = qmod.BreakType.END
285
286         return parts, words
287
288     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
289         """ Return the token information from the database for the
290             given word tokens.
291         """
292         t = self.conn.t.meta.tables['word']
293         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
294
295     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
296         """ Add tokens to query that are not saved in the database.
297         """
298         for part, node, i in zip(parts, query.nodes, range(1000)):
299             if len(part.token) <= 4 and part.token.isdigit()\
300                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
301                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
302                                 ICUToken(penalty=0.5, token=0,
303                                          count=1, addr_count=1, lookup_word=part.token,
304                                          word_token=part.token, info=None))
305
306     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
307         """ Add penalties to tokens that depend on presence of other token.
308         """
309         for i, node, tlist in query.iter_token_lists():
310             if tlist.ttype == qmod.TokenType.POSTCODE:
311                 for repl in node.starting:
312                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
313                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
314                             or len(tlist.tokens[0].lookup_word) > 4):
315                         repl.add_penalty(0.39)
316             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
317                   and len(tlist.tokens[0].lookup_word) <= 3):
318                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
319                     for repl in node.starting:
320                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
321                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
322             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
323                 norm = parts[i].normalized
324                 for j in range(i + 1, tlist.end):
325                     if node.btype != qmod.BreakType.TOKEN:
326                         norm += '  ' + parts[j].normalized
327                 for token in tlist.tokens:
328                     cast(ICUToken, token).rematch(norm)
329
330
331 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
332     out = query.nodes[0].btype.value
333     for node, part in zip(query.nodes[1:], parts):
334         out += part.token + node.btype.value
335     return out
336
337
338 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
339     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
340     for node in query.nodes:
341         for tlist in node.starting:
342             for token in tlist.tokens:
343                 t = cast(ICUToken, token)
344                 yield [tlist.ttype.name, t.token, t.word_token or '',
345                        t.lookup_word or '', t.penalty, t.count, t.info]
346
347
348 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
349     """ Create and set up a new query analyzer for a database based
350         on the ICU tokenizer.
351     """
352     out = ICUQueryAnalyzer(conn)
353     await out.setup()
354
355     return out