# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
import re
from itertools import zip_longest

from icu import Transliterator

import sqlalchemy as sa

from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from ..query_preprocessing.config import QueryConfig
from .query_analyzer_factory import AbstractQueryAnalyzer


DB_TO_TOKEN_TYPE = {
    'W': qmod.TOKEN_WORD,
    'w': qmod.TOKEN_PARTIAL,
    'H': qmod.TOKEN_HOUSENUMBER,
    'P': qmod.TOKEN_POSTCODE,
    'C': qmod.TOKEN_COUNTRY
}

PENALTY_IN_TOKEN_BREAK = {
    qmod.BREAK_START: 0.5,
    qmod.BREAK_END: 0.5,
    qmod.BREAK_PHRASE: 0.5,
    qmod.BREAK_SOFT_PHRASE: 0.5,
    qmod.BREAK_WORD: 0.1,
    qmod.BREAK_PART: 0.0,
    qmod.BREAK_TOKEN: 0.0
}


@dataclasses.dataclass
class QueryPart:
    """ Normalized and transliterated form of a single term in the query.

        When the term came out of a split during the transliteration,
        the normalized string is the full word before transliteration.
        Check the subsequent break type to figure out if the word is
        continued.

        Penalty is the break penalty for the break following the token.
    """
    token: str
    normalized: str
    penalty: float


QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]


def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
    """ Add all combinations of words in the terms list after the
        given position to the word list.
    """
    total = len(terms)
    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
    for first in range(start, total):
        word = terms[first].token
        penalty = base_penalty
        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
        for last in range(first + 1, min(first + 20, total)):
            word = ' '.join((word, terms[last].token))
            penalty += terms[last - 1].penalty
            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
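
# Illustration: for terms ['rue', 'de', 'la', 'gare'] and start=0, extract_words()
# registers lookup strings for every contiguous run starting at each position
# ('rue', 'rue de', 'rue de la', ..., 'de', 'de la', ...), each with a TokenRange
# covering its positions and a penalty that grows with the breaks joined over.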


@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Specialised token for ICU tokenizer.
    """
    word_token: str
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        assert self.info
        return self.info.get('class', ''), self.info.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Check how well the token matches the given normalized string
            and add a penalty, if necessary.
        """
        if not self.lookup_word:
            return

        seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
        distance = 0
        for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                distance += 1
            elif tag == 'replace':
                distance += max((ato-afrom), (bto-bfrom))
            elif tag != 'equal':
                distance += abs((ato-afrom) - (bto-bfrom))
        self.penalty += (distance/len(self.lookup_word))
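
    # Example: with lookup_word 'abcd' and norm 'abxd', difflib reports a single
    # one-character 'replace' opcode, so distance is 1 and the penalty grows by
    # 1/4 = 0.25.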

    @staticmethod
    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
        """ Create an ICUToken from a row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)

        penalty = base_penalty
        if row.type == 'w':
            penalty += 0.3
        elif row.type == 'W':
            if len(row.word_token) == 1 and row.word_token == row.word:
                penalty += 0.2 if row.word.isdigit() else 0.3
        elif row.type == 'H':
            penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
            if all(not c.isdigit() for c in row.word_token):
                penalty += 0.2 * (len(row.word_token) - 1)
        elif row.type == 'C':
            if len(row.word_token) == 1:
                penalty += 0.3

        if row.info is None:
            lookup_word = row.word
        else:
            lookup_word = row.info.get('lookup', row.word)
        if lookup_word:
            lookup_word = lookup_word.split('@', 1)[0]
        else:
            lookup_word = row.word_token

        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                        lookup_word=lookup_word,
                        word_token=row.word_token, info=row.info,
                        addr_count=max(1, addr_count))
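
    # Example: a house-number row (type 'H') with word_token '12a' receives
    # base_penalty + 0.1 for its single non-digit character, while a purely
    # numeric token such as '12' gets no extra penalty.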


class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by an ICU tokenizer.
    """
    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        await self._setup_preprocessing()

        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))

    async def _setup_preprocessing(self) -> None:
        """ Load the rules for preprocessing and set up the handlers.
        """

        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                        config='TOKENIZER_CONFIG')
        preprocessing_rules = rules.get('query-preprocessing', [])

        self.preprocessors = []

        for func in preprocessing_rules:
            if 'step' not in func:
                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
            if not isinstance(func['step'], str):
                raise UsageError("'step' attribute must be a simple string.")

            module = self.conn.config.load_plugin_module(
                func['step'], 'nominatim_api.query_preprocessing')
            self.preprocessors.append(
                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
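
    # The query-preprocessing section of icu_tokenizer.yaml is a list of steps,
    # for example (illustrative only; the configured steps may differ):
    #
    #   query-preprocessing:
    #       - step: normalize
    #
    # Each 'step' names a module under nominatim_api.query_preprocessing which
    # must provide a create() function that returns the preprocessing handler.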

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        for func in self.preprocessors:
            phrases = func(phrases)

        if len(phrases) == 1 \
                and phrases[0].text.count(' ') > 3 \
                and max(len(s) for s in phrases[0].text.split()) < 3:
            # A single phrase consisting of many very short words is unlikely
            # to be a searchable place name. Give up early.
            return qmod.QueryStruct([])

        query = qmod.QueryStruct(phrases)

        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
                if row.type == 'S':
                    if row.info['op'] in ('in', 'near'):
                        if trange.start == 0:
                            query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
                    else:
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TOKEN_QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
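
        # Post-processing: add artificial house-number tokens for short digit
        # terms that are missing from the word table and rerank tokens whose
        # penalty depends on competing readings of the same words.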
        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query, parts)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text)).strip('-: ')

    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
        """ Transliterate the phrases and split them into tokens.

            Returns the list of transliterated tokens together with their
            normalized form and a dictionary of words for lookup together
            with their position.
        """
        parts: QueryParts = []
        phrase_start = 0
        words: WordDict = defaultdict(list)
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            phrase_split = re.split('([ :-])', phrase.text)
            # The zip construct will give us the pairs of word/break from
            # the regular expression split. As the split array ends on the
            # final word, we simply use the fillvalue to even out the list and
            # add the phrase break at the end.
            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
                if not word:
                    continue
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            parts.append(QueryPart(term, word,
                                                   PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
                    query.nodes[-1].btype = breakchar
                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]

            extract_words(parts, phrase_start, words)

            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BREAK_END

        return parts, words
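
    # Illustration: a phrase such as '59 Rue de la Gare' is split on spaces,
    # colons and hyphens; each piece is transliterated and may split again on
    # spaces introduced by transliteration. Every resulting term becomes a
    # QueryPart recording the transliterated term, the original word and the
    # penalty of the break that follows it.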

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add tokens to query that are not saved in the database.
        """
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part.token) <= 4 and part.token.isdigit()\
               and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
                                ICUToken(penalty=0.5, token=0,
                                         count=1, addr_count=1, lookup_word=part.token,
                                         word_token=part.token, info=None))
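
    # Example: a query part '12' without a house-number token in the word table
    # still receives an artificial TOKEN_HOUSENUMBER with a fixed penalty of 0.5,
    # so unknown house numbers remain searchable.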

    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add penalties to tokens that depend on the presence of other tokens.
        """
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TOKEN_POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
                       and (repl.ttype != qmod.TOKEN_HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
                  and len(tlist.tokens[0].lookup_word) <= 3):
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
                norm = parts[i].normalized
                for j in range(i + 1, tlist.end):
                    if node.btype != qmod.BREAK_TOKEN:
                        norm += ' ' + parts[j].normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)
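
    # Example: if '12345' can be read both as a postcode and as a partial word
    # over the same range, the non-postcode reading is penalized by 0.39 so the
    # postcode interpretation is preferred.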


def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
    out = query.nodes[0].btype
    for node, part in zip(query.nodes[1:], parts):
        out += part.token + node.btype
    return out


def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for node in query.nodes:
        for tlist in node.starting:
            for token in tlist.tokens:
                t = cast(ICUToken, token)
                yield [tlist.ttype, t.token, t.word_token or '',
                       t.lookup_word or '', t.penalty, t.count, t.info]


async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
    out = ICUQueryAnalyzer(conn)
    await out.setup()

    return out
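
# Usage sketch (assumes an open SearchConnection `conn` and a list of
# qmod.Phrase objects prepared by the caller):
#
#     analyzer = await create_query_analyzer(conn)
#     query = await analyzer.analyze_query(phrases)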