# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
import re
from itertools import zip_longest

from icu import Transliterator

import sqlalchemy as sa

from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from ..query_preprocessing.config import QueryConfig
from .query_analyzer_factory import AbstractQueryAnalyzer


DB_TO_TOKEN_TYPE = {
    'W': qmod.TokenType.WORD,
    'w': qmod.TokenType.PARTIAL,
    'H': qmod.TokenType.HOUSENUMBER,
    'P': qmod.TokenType.POSTCODE,
    'C': qmod.TokenType.COUNTRY
}
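
# Special terms (type 'S' in the word table) are deliberately absent from the
# mapping above; they are turned into NEAR_ITEM or QUALIFIER tokens in
# analyze_query() depending on their position in the query.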

PENALTY_IN_TOKEN_BREAK = {
    qmod.BreakType.START: 0.5,
    qmod.BreakType.END: 0.5,
    qmod.BreakType.PHRASE: 0.5,
    qmod.BreakType.SOFT_PHRASE: 0.5,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.0,
    qmod.BreakType.TOKEN: 0.0
}
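
# Breaks at phrase and query boundaries carry a higher penalty (0.5) than word
# or part breaks within a phrase; yield_words() below accumulates these break
# penalties when it builds multi-part words.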

@dataclasses.dataclass
class QueryPart:
    """ Normalized and transliterated form of a single term in the query.
        When the term came out of a split during the transliteration,
        the normalized string is the full word before transliteration.
        The word number keeps track of the word before transliteration
        and can be used to identify partially transliterated terms.
        Penalty is the break penalty for the break following the token.
    """
    token: str
    normalized: str
    word_number: int
    penalty: float


QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]


def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Return all combinations of words in the terms list after the
        given position.
    """
    total = len(terms)
    for first in range(start, total):
        word = terms[first].token
        penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
        yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
        for last in range(first + 1, min(first + 20, total)):
            word = ' '.join((word, terms[last].token))
            penalty += terms[last - 1].penalty
            yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
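
# Example for yield_words(): for parts with the tokens ['north', 'main', 'st'],
# yield_words(parts, 0) produces 'north', 'north main', 'north main st',
# 'main', 'main st' and 'st', each paired with a TokenRange over the covered
# positions and the accumulated break penalties.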


@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Specialised token for ICU tokenizer.
    """
    word_token: str
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        assert self.info
        return self.info.get('class', ''), self.info.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Check how well the token matches the given normalized string
            and add a penalty, if necessary.
        """
        if not self.lookup_word:
            return

        seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
        distance = 0
        for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                distance += 1
            elif tag == 'replace':
                distance += max((ato-afrom), (bto-bfrom))
            elif tag != 'equal':
                distance += abs((ato-afrom) - (bto-bfrom))
        self.penalty += (distance/len(self.lookup_word))
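        # Example: matching the lookup word 'hauptstrasse' against the
        # normalized string 'hauptstr' leaves only a short deletion at the end
        # of the string, so the added penalty stays small; completely unrelated
        # strings accumulate a distance close to the word length.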

    @staticmethod
    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
        """ Create an ICUToken from a row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)

        penalty = base_penalty
        if row.type == 'w':
            penalty += 0.3
        elif row.type == 'W':
            if len(row.word_token) == 1 and row.word_token == row.word:
                penalty += 0.2 if row.word.isdigit() else 0.3
        elif row.type == 'H':
            penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
            if all(not c.isdigit() for c in row.word_token):
                penalty += 0.2 * (len(row.word_token) - 1)
        elif row.type == 'C':
            if len(row.word_token) == 1:
                penalty += 0.3

        if row.info is None:
            lookup_word = row.word
        else:
            lookup_word = row.info.get('lookup', row.word)
        if lookup_word:
            lookup_word = lookup_word.split('@', 1)[0]
        else:
            lookup_word = row.word_token

        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                        lookup_word=lookup_word,
                        word_token=row.word_token, info=row.info,
                        addr_count=max(1, addr_count))
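        # Example for the penalties above: a one-letter full word gets an extra
        # 0.3 (0.2 if it is a digit), and a house-number token like '12a' gets
        # 0.1 for each non-digit character.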


class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by an ICU tokenizer.
    """
    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        await self._setup_preprocessing()

        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))

    async def _setup_preprocessing(self) -> None:
        """ Load the rules for preprocessing and set up the handlers.
        """
        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                        config='TOKENIZER_CONFIG')
        preprocessing_rules = rules.get('query-preprocessing', [])

        self.preprocessors = []

        for func in preprocessing_rules:
            if 'step' not in func:
                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
            if not isinstance(func['step'], str):
                raise UsageError("'step' attribute must be a simple string.")

            module = self.conn.config.load_plugin_module(
                        func['step'], 'nominatim_api.query_preprocessing')
            self.preprocessors.append(
                module.create(QueryConfig(func).set_normalizer(self.normalizer)))

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        for func in self.preprocessors:
            phrases = func(phrases)
        query = qmod.QueryStruct(phrases)

        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
                if row.type == 'S':
                    if row.info['op'] in ('in', 'near'):
                        if trange.start == 0:
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                    else:
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query, parts)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text))

    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
        """ Transliterate the phrases and split them into tokens.

            Returns the list of transliterated tokens together with their
            normalized form and a dictionary of words for lookup together
            with their position.
        """
        parts: QueryParts = []
        phrase_start = 0
        words = defaultdict(list)
        wordnr = 0
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            phrase_split = re.split('([ :-])', phrase.text)
            # The zip construct will give us the pairs of word/break from
            # the regular expression split. As the split array ends on the
            # final word, we simply use the fillvalue to even out the list and
            # add the phrase break at the end.
            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
                if not word:
                    continue
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            parts.append(QueryPart(term, word, wordnr,
                                                   PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
                            query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType(breakchar)
                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
                wordnr += 1

            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)

            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add tokens to the query that are not saved in the database.
        """
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part.token) <= 4 and part.token.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                ICUToken(penalty=0.5, token=0,
                                         count=1, addr_count=1, lookup_word=part.token,
                                         word_token=part.token, info=None))
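        # Example: a stand-alone '12' with no entry in the word table still
        # becomes a HOUSENUMBER candidate here, albeit with a comparatively
        # high penalty of 0.5.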

    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add penalties to tokens that depend on the presence of other tokens.
        """
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
                  and len(tlist.tokens[0].lookup_word) <= 3):
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                norm = parts[i].normalized
                for j in range(i + 1, tlist.end):
                    if parts[j - 1].word_number != parts[j].word_number:
                        norm += ' ' + parts[j].normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)
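        # Note: the comparison string for rematch() is rebuilt from the
        # normalized (pre-transliteration) forms of the covered query parts,
        # joining parts only where the original word changes.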


def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
    out = query.nodes[0].btype.value
    for node, part in zip(query.nodes[1:], parts):
        out += part.token + node.btype.value
    return out


def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for node in query.nodes:
        for tlist in node.starting:
            for token in tlist.tokens:
                t = cast(ICUToken, token)
                yield [tlist.ttype.name, t.token, t.word_token or '',
                       t.lookup_word or '', t.penalty, t.count, t.info]


async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
    out = ICUQueryAnalyzer(conn)
    await out.setup()
    return out
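
# Typical use (sketch): given an open SearchConnection `conn` and the list of
# phrases produced by the search frontend,
#     analyzer = await create_query_analyzer(conn)
#     query = await analyzer.analyze_query(phrases)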