]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
add inner word break penalty
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14 import re
15 from itertools import zip_longest
16
17 from icu import Transliterator
18
19 import sqlalchemy as sa
20
21 from ..errors import UsageError
22 from ..typing import SaRow
23 from ..sql.sqlalchemy_types import Json
24 from ..connection import SearchConnection
25 from ..logging import log
26 from . import query as qmod
27 from ..query_preprocessing.config import QueryConfig
28 from .query_analyzer_factory import AbstractQueryAnalyzer
29
30
31 DB_TO_TOKEN_TYPE = {
32     'W': qmod.TokenType.WORD,
33     'w': qmod.TokenType.PARTIAL,
34     'H': qmod.TokenType.HOUSENUMBER,
35     'P': qmod.TokenType.POSTCODE,
36     'C': qmod.TokenType.COUNTRY
37 }
38
39 PENALTY_IN_TOKEN_BREAK = {
40      qmod.BreakType.START: 0.5,
41      qmod.BreakType.END: 0.5,
42      qmod.BreakType.PHRASE: 0.5,
43      qmod.BreakType.SOFT_PHRASE: 0.5,
44      qmod.BreakType.WORD: 0.1,
45      qmod.BreakType.PART: 0.0,
46      qmod.BreakType.TOKEN: 0.0
47 }
48
49
50 @dataclasses.dataclass
51 class QueryPart:
52     """ Normalized and transliterated form of a single term in the query.
53         When the term came out of a split during the transliteration,
54         the normalized string is the full word before transliteration.
55         The word number keeps track of the word before transliteration
56         and can be used to identify partial transliterated terms.
57         Penalty is the break penalty for the break following the token.
58     """
59     token: str
60     normalized: str
61     word_number: int
62     penalty: float
63
64
65 QueryParts = List[QueryPart]
66 WordDict = Dict[str, List[qmod.TokenRange]]
67
68
69 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
70     """ Return all combinations of words in the terms list after the
71         given position.
72     """
73     total = len(terms)
74     for first in range(start, total):
75         word = terms[first].token
76         penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
77         yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
78         for last in range(first + 1, min(first + 20, total)):
79             word = ' '.join((word, terms[last].token))
80             penalty += terms[last - 1].penalty
81             yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
82
83
84 @dataclasses.dataclass
85 class ICUToken(qmod.Token):
86     """ Specialised token for ICU tokenizer.
87     """
88     word_token: str
89     info: Optional[Dict[str, Any]]
90
91     def get_category(self) -> Tuple[str, str]:
92         assert self.info
93         return self.info.get('class', ''), self.info.get('type', '')
94
95     def rematch(self, norm: str) -> None:
96         """ Check how well the token matches the given normalized string
97             and add a penalty, if necessary.
98         """
99         if not self.lookup_word:
100             return
101
102         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
103         distance = 0
104         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
105             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
106                 distance += 1
107             elif tag == 'replace':
108                 distance += max((ato-afrom), (bto-bfrom))
109             elif tag != 'equal':
110                 distance += abs((ato-afrom) - (bto-bfrom))
111         self.penalty += (distance/len(self.lookup_word))
112
113     @staticmethod
114     def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
115         """ Create a ICUToken from the row of the word table.
116         """
117         count = 1 if row.info is None else row.info.get('count', 1)
118         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
119
120         penalty = base_penalty
121         if row.type == 'w':
122             penalty += 0.3
123         elif row.type == 'W':
124             if len(row.word_token) == 1 and row.word_token == row.word:
125                 penalty += 0.2 if row.word.isdigit() else 0.3
126         elif row.type == 'H':
127             penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
128             if all(not c.isdigit() for c in row.word_token):
129                 penalty += 0.2 * (len(row.word_token) - 1)
130         elif row.type == 'C':
131             if len(row.word_token) == 1:
132                 penalty += 0.3
133
134         if row.info is None:
135             lookup_word = row.word
136         else:
137             lookup_word = row.info.get('lookup', row.word)
138         if lookup_word:
139             lookup_word = lookup_word.split('@', 1)[0]
140         else:
141             lookup_word = row.word_token
142
143         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
144                         lookup_word=lookup_word,
145                         word_token=row.word_token, info=row.info,
146                         addr_count=max(1, addr_count))
147
148
149 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
150     """ Converter for query strings into a tokenized query
151         using the tokens created by a ICU tokenizer.
152     """
153     def __init__(self, conn: SearchConnection) -> None:
154         self.conn = conn
155
156     async def setup(self) -> None:
157         """ Set up static data structures needed for the analysis.
158         """
159         async def _make_normalizer() -> Any:
160             rules = await self.conn.get_property('tokenizer_import_normalisation')
161             return Transliterator.createFromRules("normalization", rules)
162
163         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
164                                                            _make_normalizer)
165
166         async def _make_transliterator() -> Any:
167             rules = await self.conn.get_property('tokenizer_import_transliteration')
168             return Transliterator.createFromRules("transliteration", rules)
169
170         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
171                                                                _make_transliterator)
172
173         await self._setup_preprocessing()
174
175         if 'word' not in self.conn.t.meta.tables:
176             sa.Table('word', self.conn.t.meta,
177                      sa.Column('word_id', sa.Integer),
178                      sa.Column('word_token', sa.Text, nullable=False),
179                      sa.Column('type', sa.Text, nullable=False),
180                      sa.Column('word', sa.Text),
181                      sa.Column('info', Json))
182
183     async def _setup_preprocessing(self) -> None:
184         """ Load the rules for preprocessing and set up the handlers.
185         """
186
187         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
188                                                         config='TOKENIZER_CONFIG')
189         preprocessing_rules = rules.get('query-preprocessing', [])
190
191         self.preprocessors = []
192
193         for func in preprocessing_rules:
194             if 'step' not in func:
195                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
196             if not isinstance(func['step'], str):
197                 raise UsageError("'step' attribute must be a simple string.")
198
199             module = self.conn.config.load_plugin_module(
200                         func['step'], 'nominatim_api.query_preprocessing')
201             self.preprocessors.append(
202                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
203
204     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
205         """ Analyze the given list of phrases and return the
206             tokenized query.
207         """
208         log().section('Analyze query (using ICU tokenizer)')
209         for func in self.preprocessors:
210             phrases = func(phrases)
211         query = qmod.QueryStruct(phrases)
212
213         log().var_dump('Normalized query', query.source)
214         if not query.source:
215             return query
216
217         parts, words = self.split_query(query)
218         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
219
220         for row in await self.lookup_in_db(list(words.keys())):
221             for trange in words[row.word_token]:
222                 token = ICUToken.from_db_row(row, trange.penalty or 0.0)
223                 if row.type == 'S':
224                     if row.info['op'] in ('in', 'near'):
225                         if trange.start == 0:
226                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
227                     else:
228                         if trange.start == 0 and trange.end == query.num_token_slots():
229                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
230                         else:
231                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
232                 else:
233                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
234
235         self.add_extra_tokens(query, parts)
236         self.rerank_tokens(query, parts)
237
238         log().table_dump('Word tokens', _dump_word_tokens(query))
239
240         return query
241
242     def normalize_text(self, text: str) -> str:
243         """ Bring the given text into a normalized form. That is the
244             standardized form search will work with. All information removed
245             at this stage is inevitably lost.
246         """
247         return cast(str, self.normalizer.transliterate(text))
248
249     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
250         """ Transliterate the phrases and split them into tokens.
251
252             Returns the list of transliterated tokens together with their
253             normalized form and a dictionary of words for lookup together
254             with their position.
255         """
256         parts: QueryParts = []
257         phrase_start = 0
258         words = defaultdict(list)
259         wordnr = 0
260         for phrase in query.source:
261             query.nodes[-1].ptype = phrase.ptype
262             phrase_split = re.split('([ :-])', phrase.text)
263             # The zip construct will give us the pairs of word/break from
264             # the regular expression split. As the split array ends on the
265             # final word, we simply use the fillvalue to even out the list and
266             # add the phrase break at the end.
267             for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
268                 if not word:
269                     continue
270                 trans = self.transliterator.transliterate(word)
271                 if trans:
272                     for term in trans.split(' '):
273                         if term:
274                             parts.append(QueryPart(term, word, wordnr,
275                                                    PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
276                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
277                     query.nodes[-1].btype = qmod.BreakType(breakchar)
278                     parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
279                 wordnr += 1
280
281             for word, wrange in yield_words(parts, phrase_start):
282                 words[word].append(wrange)
283
284             phrase_start = len(parts)
285         query.nodes[-1].btype = qmod.BreakType.END
286
287         return parts, words
288
289     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
290         """ Return the token information from the database for the
291             given word tokens.
292         """
293         t = self.conn.t.meta.tables['word']
294         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
295
296     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
297         """ Add tokens to query that are not saved in the database.
298         """
299         for part, node, i in zip(parts, query.nodes, range(1000)):
300             if len(part.token) <= 4 and part.token.isdigit()\
301                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
302                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
303                                 ICUToken(penalty=0.5, token=0,
304                                          count=1, addr_count=1, lookup_word=part.token,
305                                          word_token=part.token, info=None))
306
307     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
308         """ Add penalties to tokens that depend on presence of other token.
309         """
310         for i, node, tlist in query.iter_token_lists():
311             if tlist.ttype == qmod.TokenType.POSTCODE:
312                 for repl in node.starting:
313                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
314                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
315                             or len(tlist.tokens[0].lookup_word) > 4):
316                         repl.add_penalty(0.39)
317             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
318                   and len(tlist.tokens[0].lookup_word) <= 3):
319                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
320                     for repl in node.starting:
321                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
322                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
323             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
324                 norm = parts[i].normalized
325                 for j in range(i + 1, tlist.end):
326                     if parts[j - 1].word_number != parts[j].word_number:
327                         norm += '  ' + parts[j].normalized
328                 for token in tlist.tokens:
329                     cast(ICUToken, token).rematch(norm)
330
331
332 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
333     out = query.nodes[0].btype.value
334     for node, part in zip(query.nodes[1:], parts):
335         out += part.token + node.btype.value
336     return out
337
338
339 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
340     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
341     for node in query.nodes:
342         for tlist in node.starting:
343             for token in tlist.tokens:
344                 t = cast(ICUToken, token)
345                 yield [tlist.ttype.name, t.token, t.word_token or '',
346                        t.lookup_word or '', t.penalty, t.count, t.info]
347
348
349 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
350     """ Create and set up a new query analyzer for a database based
351         on the ICU tokenizer.
352     """
353     out = ICUQueryAnalyzer(conn)
354     await out.setup()
355
356     return out