# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
import re
from itertools import zip_longest

from icu import Transliterator

import sqlalchemy as sa

from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from ..query_preprocessing.config import QueryConfig
from .query_analyzer_factory import AbstractQueryAnalyzer


DB_TO_TOKEN_TYPE = {
    'W': qmod.TokenType.WORD,
    'w': qmod.TokenType.PARTIAL,
    'H': qmod.TokenType.HOUSENUMBER,
    'P': qmod.TokenType.POSTCODE,
    'C': qmod.TokenType.COUNTRY
}

PENALTY_IN_TOKEN_BREAK = {
    qmod.BreakType.START: 0.5,
    qmod.BreakType.END: 0.5,
    qmod.BreakType.PHRASE: 0.5,
    qmod.BreakType.SOFT_PHRASE: 0.5,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.0,
    qmod.BreakType.TOKEN: 0.0
}
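# Note on PENALTY_IN_TOKEN_BREAK: it maps the type of break that follows a
# term to the penalty incurred when a multi-term word is assembled across
# that break. For example, combining two terms across a WORD break costs 0.1,
# while combining across a TOKEN break (a split introduced by the
# transliteration itself) is free.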


@dataclasses.dataclass
class QueryPart:
    """ Normalized and transliterated form of a single term in the query.

        When the term came out of a split during the transliteration,
        the normalized string is the full word before transliteration.
        Check the subsequent break type to figure out if the word is
        continued.

        Penalty is the break penalty for the break following the token.
    """
    token: str
    normalized: str
    penalty: float


QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]


def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Return all combinations of words in the terms list after the
        given position.
    """
    total = len(terms)
    for first in range(start, total):
        word = terms[first].token
        penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
        yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
        for last in range(first + 1, min(first + 20, total)):
            word = ' '.join((word, terms[last].token))
            penalty += terms[last - 1].penalty
            yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
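# Illustrative example: for parts transliterated from "new york" and start=0,
# yield_words() produces the lookup strings "new" (range 0-1), "new york"
# (range 0-2) and "york" (range 1-2). The penalty of a multi-term word grows
# by the break penalty of every term boundary it spans; word combinations are
# capped at 20 consecutive terms.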


@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Specialised token for ICU tokenizer.
    """
    word_token: str
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        assert self.info
        return self.info.get('class', ''), self.info.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Check how well the token matches the given normalized string
            and add a penalty, if necessary.
        """
        if not self.lookup_word:
            return

        seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
        distance = 0
        for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                distance += 1
            elif tag == 'replace':
                distance += max((ato-afrom), (bto-bfrom))
            elif tag != 'equal':
                distance += abs((ato-afrom) - (bto-bfrom))
        self.penalty += (distance/len(self.lookup_word))
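    # Rough intuition for rematch(): the added penalty is an edit distance
    # between the token's lookup word and the normalized query word, relative
    # to the length of the lookup word. Insertions or deletions at the very
    # start or end of the lookup word count only as 1, so a prefix match such
    # as lookup word "haupt" against "hauptstrasse" is punished only lightly.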

    @staticmethod
    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
        """ Create an ICUToken from the row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)

        penalty = base_penalty
        if row.type == 'w':
            penalty += 0.3
        elif row.type == 'W':
            if len(row.word_token) == 1 and row.word_token == row.word:
                penalty += 0.2 if row.word.isdigit() else 0.3
        elif row.type == 'H':
            penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
            if all(not c.isdigit() for c in row.word_token):
                penalty += 0.2 * (len(row.word_token) - 1)
        elif row.type == 'C':
            if len(row.word_token) == 1:
                penalty += 0.3

        if row.info is None:
            lookup_word = row.word
        else:
            lookup_word = row.info.get('lookup', row.word)
        if lookup_word:
            lookup_word = lookup_word.split('@', 1)[0]
        else:
            lookup_word = row.word_token

        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                        lookup_word=lookup_word,
                        word_token=row.word_token, info=row.info,
                        addr_count=max(1, addr_count))
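    # Example of the penalty heuristics above: a partial word ('w') starts 0.3
    # above the base penalty, a one-character full word that is its own token
    # gets 0.3 (0.2 if it is a digit), and a housenumber token pays 0.1 for
    # every non-digit character, so "25a" costs an extra 0.1 while an
    # all-letter housenumber is penalised much more heavily.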


class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by an ICU tokenizer.
    """
    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        await self._setup_preprocessing()

        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))

    async def _setup_preprocessing(self) -> None:
        """ Load the rules for preprocessing and set up the handlers.
        """
        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                        config='TOKENIZER_CONFIG')
        preprocessing_rules = rules.get('query-preprocessing', [])

        self.preprocessors = []

        for func in preprocessing_rules:
            if 'step' not in func:
                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
            if not isinstance(func['step'], str):
                raise UsageError("'step' attribute must be a simple string.")

            module = self.conn.config.load_plugin_module(
                func['step'], 'nominatim_api.query_preprocessing')
            self.preprocessors.append(
                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
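    # Sketch of the configuration consumed by _setup_preprocessing() above
    # (assumed shape of the query-preprocessing section in icu_tokenizer.yaml):
    # a list of mappings, each naming a preprocessing module via its 'step'
    # key, roughly like
    #
    #   query-preprocessing:
    #       - step: normalize
    #
    # Each step name is resolved as a plugin module below
    # nominatim_api.query_preprocessing.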

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        for func in self.preprocessors:
            phrases = func(phrases)
        query = qmod.QueryStruct(phrases)

        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
                if row.type == 'S':
                    if row.info['op'] in ('in', 'near'):
                        if trange.start == 0:
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                    else:
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query, parts)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query
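    # Usage sketch (illustrative; the phrase construction is an assumption
    # based on the surrounding API): the analyzer is normally obtained through
    # create_query_analyzer() below and fed pre-split phrases, e.g.
    #
    #   analyzer = await create_query_analyzer(conn)
    #   query = await analyzer.analyze_query(
    #       [qmod.Phrase(qmod.PhraseType.NONE, 'new york')])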

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text)).strip('-: ')

    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
        """ Transliterate the phrases and split them into tokens.

            Returns the list of transliterated tokens together with their
            normalized form and a dictionary of words for lookup together
            with their position.
        """
        parts: QueryParts = []
        phrase_start = 0
        words = defaultdict(list)
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            phrase_split = re.split('([ :-])', phrase.text)
            # The zip construct will give us the pairs of word/break from
            # the regular expression split. As the split array ends on the
            # final word, we simply use the fillvalue to even out the list and
            # add the phrase break at the end.
            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
                if not word:
                    continue
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            parts.append(QueryPart(term, word,
                                                   PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
                            query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType(breakchar)
                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]

            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)

            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words
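    # Example (illustrative): the phrase "Hauptstr. 64" is split on the break
    # characters space, colon and hyphen, each piece is transliterated, and
    # terms that the transliteration splits further are joined by TOKEN breaks,
    # which carry no penalty when words are later recombined for lookup.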

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add tokens to query that are not saved in the database.
        """
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part.token) <= 4 and part.token.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                ICUToken(penalty=0.5, token=0,
                                         count=1, addr_count=1, lookup_word=part.token,
                                         word_token=part.token, info=None))
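    # In effect, any purely numeric part of up to four digits that did not get
    # a housenumber token from the word table (e.g. "25" in a query for a
    # street where no such housenumber is indexed) still receives an artificial
    # HOUSENUMBER token, at the comparatively high penalty of 0.5.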

    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add penalties to tokens that depend on the presence of other tokens.
        """
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
                  and len(tlist.tokens[0].lookup_word) <= 3):
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                norm = parts[i].normalized
                for j in range(i + 1, tlist.end):
                    if node.btype != qmod.BreakType.TOKEN:
                        norm += ' ' + parts[j].normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)
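    # Rough intuition for the rules above: alternative readings covering the
    # same range as a postcode receive an extra 0.39 penalty (housenumbers are
    # exempt when the postcode is short), tokens competing with a short
    # numeric housenumber are penalised so that the housenumber reading is
    # preferred, and all remaining name-like tokens are re-scored against the
    # normalized query words via ICUToken.rematch().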


def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
    out = query.nodes[0].btype.value
    for node, part in zip(query.nodes[1:], parts):
        out += part.token + node.btype.value
    return out


def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for node in query.nodes:
        for tlist in node.starting:
            for token in tlist.tokens:
                t = cast(ICUToken, token)
                yield [tlist.ttype.name, t.token, t.word_token or '',
                       t.lookup_word or '', t.penalty, t.count, t.info]


async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
    out = ICUQueryAnalyzer(conn)
    await out.setup()

    return out