]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
remove word_number counting for phrases
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14 import re
15 from itertools import zip_longest
16
17 from icu import Transliterator
18
19 import sqlalchemy as sa
20
21 from ..errors import UsageError
22 from ..typing import SaRow
23 from ..sql.sqlalchemy_types import Json
24 from ..connection import SearchConnection
25 from ..logging import log
26 from . import query as qmod
27 from ..query_preprocessing.config import QueryConfig
28 from .query_analyzer_factory import AbstractQueryAnalyzer
29
30
31 DB_TO_TOKEN_TYPE = {
32     'W': qmod.TokenType.WORD,
33     'w': qmod.TokenType.PARTIAL,
34     'H': qmod.TokenType.HOUSENUMBER,
35     'P': qmod.TokenType.POSTCODE,
36     'C': qmod.TokenType.COUNTRY
37 }
38
39 PENALTY_IN_TOKEN_BREAK = {
40      qmod.BreakType.START: 0.5,
41      qmod.BreakType.END: 0.5,
42      qmod.BreakType.PHRASE: 0.5,
43      qmod.BreakType.SOFT_PHRASE: 0.5,
44      qmod.BreakType.WORD: 0.1,
45      qmod.BreakType.PART: 0.0,
46      qmod.BreakType.TOKEN: 0.0
47 }
48
49
50 @dataclasses.dataclass
51 class QueryPart:
52     """ Normalized and transliterated form of a single term in the query.
53
54         When the term came out of a split during the transliteration,
55         the normalized string is the full word before transliteration.
56         Check the subsequent break type to figure out if the word is
57         continued.
58
59         Penalty is the break penalty for the break following the token.
60     """
61     token: str
62     normalized: str
63     penalty: float
64
65
66 QueryParts = List[QueryPart]
67 WordDict = Dict[str, List[qmod.TokenRange]]
68
69
70 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
71     """ Return all combinations of words in the terms list after the
72         given position.
73     """
74     total = len(terms)
75     for first in range(start, total):
76         word = terms[first].token
77         penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
78         yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
79         for last in range(first + 1, min(first + 20, total)):
80             word = ' '.join((word, terms[last].token))
81             penalty += terms[last - 1].penalty
82             yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
83
84
85 @dataclasses.dataclass
86 class ICUToken(qmod.Token):
87     """ Specialised token for ICU tokenizer.
88     """
89     word_token: str
90     info: Optional[Dict[str, Any]]
91
92     def get_category(self) -> Tuple[str, str]:
93         assert self.info
94         return self.info.get('class', ''), self.info.get('type', '')
95
96     def rematch(self, norm: str) -> None:
97         """ Check how well the token matches the given normalized string
98             and add a penalty, if necessary.
99         """
100         if not self.lookup_word:
101             return
102
103         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
104         distance = 0
105         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
106             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
107                 distance += 1
108             elif tag == 'replace':
109                 distance += max((ato-afrom), (bto-bfrom))
110             elif tag != 'equal':
111                 distance += abs((ato-afrom) - (bto-bfrom))
112         self.penalty += (distance/len(self.lookup_word))
113
114     @staticmethod
115     def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
116         """ Create a ICUToken from the row of the word table.
117         """
118         count = 1 if row.info is None else row.info.get('count', 1)
119         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
120
121         penalty = base_penalty
122         if row.type == 'w':
123             penalty += 0.3
124         elif row.type == 'W':
125             if len(row.word_token) == 1 and row.word_token == row.word:
126                 penalty += 0.2 if row.word.isdigit() else 0.3
127         elif row.type == 'H':
128             penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
129             if all(not c.isdigit() for c in row.word_token):
130                 penalty += 0.2 * (len(row.word_token) - 1)
131         elif row.type == 'C':
132             if len(row.word_token) == 1:
133                 penalty += 0.3
134
135         if row.info is None:
136             lookup_word = row.word
137         else:
138             lookup_word = row.info.get('lookup', row.word)
139         if lookup_word:
140             lookup_word = lookup_word.split('@', 1)[0]
141         else:
142             lookup_word = row.word_token
143
144         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
145                         lookup_word=lookup_word,
146                         word_token=row.word_token, info=row.info,
147                         addr_count=max(1, addr_count))
148
149
150 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
151     """ Converter for query strings into a tokenized query
152         using the tokens created by a ICU tokenizer.
153     """
154     def __init__(self, conn: SearchConnection) -> None:
155         self.conn = conn
156
157     async def setup(self) -> None:
158         """ Set up static data structures needed for the analysis.
159         """
160         async def _make_normalizer() -> Any:
161             rules = await self.conn.get_property('tokenizer_import_normalisation')
162             return Transliterator.createFromRules("normalization", rules)
163
164         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
165                                                            _make_normalizer)
166
167         async def _make_transliterator() -> Any:
168             rules = await self.conn.get_property('tokenizer_import_transliteration')
169             return Transliterator.createFromRules("transliteration", rules)
170
171         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
172                                                                _make_transliterator)
173
174         await self._setup_preprocessing()
175
176         if 'word' not in self.conn.t.meta.tables:
177             sa.Table('word', self.conn.t.meta,
178                      sa.Column('word_id', sa.Integer),
179                      sa.Column('word_token', sa.Text, nullable=False),
180                      sa.Column('type', sa.Text, nullable=False),
181                      sa.Column('word', sa.Text),
182                      sa.Column('info', Json))
183
184     async def _setup_preprocessing(self) -> None:
185         """ Load the rules for preprocessing and set up the handlers.
186         """
187
188         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
189                                                         config='TOKENIZER_CONFIG')
190         preprocessing_rules = rules.get('query-preprocessing', [])
191
192         self.preprocessors = []
193
194         for func in preprocessing_rules:
195             if 'step' not in func:
196                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
197             if not isinstance(func['step'], str):
198                 raise UsageError("'step' attribute must be a simple string.")
199
200             module = self.conn.config.load_plugin_module(
201                         func['step'], 'nominatim_api.query_preprocessing')
202             self.preprocessors.append(
203                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
204
205     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
206         """ Analyze the given list of phrases and return the
207             tokenized query.
208         """
209         log().section('Analyze query (using ICU tokenizer)')
210         for func in self.preprocessors:
211             phrases = func(phrases)
212         query = qmod.QueryStruct(phrases)
213
214         log().var_dump('Normalized query', query.source)
215         if not query.source:
216             return query
217
218         parts, words = self.split_query(query)
219         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
220
221         for row in await self.lookup_in_db(list(words.keys())):
222             for trange in words[row.word_token]:
223                 token = ICUToken.from_db_row(row, trange.penalty or 0.0)
224                 if row.type == 'S':
225                     if row.info['op'] in ('in', 'near'):
226                         if trange.start == 0:
227                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
228                     else:
229                         if trange.start == 0 and trange.end == query.num_token_slots():
230                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
231                         else:
232                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
233                 else:
234                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
235
236         self.add_extra_tokens(query, parts)
237         self.rerank_tokens(query, parts)
238
239         log().table_dump('Word tokens', _dump_word_tokens(query))
240
241         return query
242
243     def normalize_text(self, text: str) -> str:
244         """ Bring the given text into a normalized form. That is the
245             standardized form search will work with. All information removed
246             at this stage is inevitably lost.
247         """
248         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
249
250     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
251         """ Transliterate the phrases and split them into tokens.
252
253             Returns the list of transliterated tokens together with their
254             normalized form and a dictionary of words for lookup together
255             with their position.
256         """
257         parts: QueryParts = []
258         phrase_start = 0
259         words = defaultdict(list)
260         for phrase in query.source:
261             query.nodes[-1].ptype = phrase.ptype
262             phrase_split = re.split('([ :-])', phrase.text)
263             # The zip construct will give us the pairs of word/break from
264             # the regular expression split. As the split array ends on the
265             # final word, we simply use the fillvalue to even out the list and
266             # add the phrase break at the end.
267             for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
268                 if not word:
269                     continue
270                 trans = self.transliterator.transliterate(word)
271                 if trans:
272                     for term in trans.split(' '):
273                         if term:
274                             parts.append(QueryPart(term, word,
275                                                    PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
276                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
277                     query.nodes[-1].btype = qmod.BreakType(breakchar)
278                     parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
279
280             for word, wrange in yield_words(parts, phrase_start):
281                 words[word].append(wrange)
282
283             phrase_start = len(parts)
284         query.nodes[-1].btype = qmod.BreakType.END
285
286         return parts, words
287
288     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
289         """ Return the token information from the database for the
290             given word tokens.
291         """
292         t = self.conn.t.meta.tables['word']
293         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
294
295     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
296         """ Add tokens to query that are not saved in the database.
297         """
298         for part, node, i in zip(parts, query.nodes, range(1000)):
299             if len(part.token) <= 4 and part.token.isdigit()\
300                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
301                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
302                                 ICUToken(penalty=0.5, token=0,
303                                          count=1, addr_count=1, lookup_word=part.token,
304                                          word_token=part.token, info=None))
305
306     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
307         """ Add penalties to tokens that depend on presence of other token.
308         """
309         for i, node, tlist in query.iter_token_lists():
310             if tlist.ttype == qmod.TokenType.POSTCODE:
311                 for repl in node.starting:
312                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
313                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
314                             or len(tlist.tokens[0].lookup_word) > 4):
315                         repl.add_penalty(0.39)
316             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
317                   and len(tlist.tokens[0].lookup_word) <= 3):
318                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
319                     for repl in node.starting:
320                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
321                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
322             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
323                 norm = parts[i].normalized
324                 for j in range(i + 1, tlist.end):
325                     if node.btype != qmod.BreakType.TOKEN:
326                         norm += '  ' + parts[j].normalized
327                 for token in tlist.tokens:
328                     cast(ICUToken, token).rematch(norm)
329
330
331 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
332     out = query.nodes[0].btype.value
333     for node, part in zip(query.nodes[1:], parts):
334         out += part.token + node.btype.value
335     return out
336
337
338 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
339     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
340     for node in query.nodes:
341         for tlist in node.starting:
342             for token in tlist.tokens:
343                 t = cast(ICUToken, token)
344                 yield [tlist.ttype.name, t.token, t.word_token or '',
345                        t.lookup_word or '', t.penalty, t.count, t.info]
346
347
348 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
349     """ Create and set up a new query analyzer for a database based
350         on the ICU tokenizer.
351     """
352     out = ICUQueryAnalyzer(conn)
353     await out.setup()
354
355     return out