# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
import re
from itertools import zip_longest

from icu import Transliterator
import sqlalchemy as sa

from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from ..query_preprocessing.config import QueryConfig
from .query_analyzer_factory import AbstractQueryAnalyzer


DB_TO_TOKEN_TYPE = {
    'W': qmod.TokenType.WORD,
    'w': qmod.TokenType.PARTIAL,
    'H': qmod.TokenType.HOUSENUMBER,
    'P': qmod.TokenType.POSTCODE,
    'C': qmod.TokenType.COUNTRY
}

PENALTY_IN_TOKEN_BREAK = {
    qmod.BreakType.START: 0.5,
    qmod.BreakType.END: 0.5,
    qmod.BreakType.PHRASE: 0.5,
    qmod.BreakType.SOFT_PHRASE: 0.5,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.0,
    qmod.BreakType.TOKEN: 0.0
}


@dataclasses.dataclass
class QueryPart:
    """ Normalized and transliterated form of a single term in the query.

        When the term came out of a split during the transliteration,
        the normalized string is the full word before transliteration.
        Check the subsequent break type to figure out if the word is
        continued.

        Penalty is the break penalty for the break following the token.
    """
    token: str
    normalized: str
    penalty: float


QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]


def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
    """ Add all combinations of words in the terms list after the
        given position to the word list.
    """
    total = len(terms)
    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
    for first in range(start, total):
        word = terms[first].token
        penalty = base_penalty
        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
        for last in range(first + 1, min(first + 20, total)):
            word = ' '.join((word, terms[last].token))
            penalty += terms[last - 1].penalty
            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
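
# Illustrative sketch (not part of the module): for parts built from
# 'hauptstr 45a' with start=0, extract_words() fills `words` roughly as
#   'hauptstr'     -> [TokenRange(0, 1, penalty=0.1)]
#   '45a'          -> [TokenRange(1, 2, penalty=0.1)]
#   'hauptstr 45a' -> [TokenRange(0, 2, penalty=0.1 + break penalty of 'hauptstr')]
# so every span of up to 20 consecutive terms becomes a lookup candidate.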


@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Specialised token for ICU tokenizer.
    """
    word_token: str
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        assert self.info
        return self.info.get('class', ''), self.info.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Check how well the token matches the given normalized string
            and add a penalty, if necessary.
        """
        if not self.lookup_word:
            return

        seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
        distance = 0
        for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                distance += 1
            elif tag == 'replace':
                distance += max((ato-afrom), (bto-bfrom))
            elif tag != 'equal':
                distance += abs((ato-afrom) - (bto-bfrom))
        self.penalty += (distance/len(self.lookup_word))
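
    # Illustrative sketch (not part of the module): matching lookup_word
    # 'grand' against the normalized string 'grant' gives the opcodes
    # [('equal', 0, 4, 0, 4), ('replace', 4, 5, 4, 5)], so distance is 1
    # and the token penalty grows by 1/5 = 0.2.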

    @staticmethod
    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
        """ Create an ICUToken from a row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)

        penalty = base_penalty
        if row.type == 'w':
            penalty += 0.3
        elif row.type == 'W':
            if len(row.word_token) == 1 and row.word_token == row.word:
                penalty += 0.2 if row.word.isdigit() else 0.3
        elif row.type == 'H':
            penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
            if all(not c.isdigit() for c in row.word_token):
                penalty += 0.2 * (len(row.word_token) - 1)
        elif row.type == 'C':
            if len(row.word_token) == 1:
                penalty += 0.3

        if row.info is None:
            lookup_word = row.word
        else:
            lookup_word = row.info.get('lookup', row.word)
        if lookup_word:
            lookup_word = lookup_word.split('@', 1)[0]
        else:
            lookup_word = row.word_token

        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                        lookup_word=lookup_word,
                        word_token=row.word_token, info=row.info,
                        addr_count=max(1, addr_count))
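
    # Illustrative sketch (not part of the module): a partial-word row
    # (type 'w') starts out at base_penalty + 0.3, while a full word
    # (type 'W') is only penalized when it is a single character, e.g. a
    # row with word_token == word == '5' gets an extra 0.2.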


class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by an ICU tokenizer.
    """
    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        await self._setup_preprocessing()

        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))

    async def _setup_preprocessing(self) -> None:
        """ Load the rules for preprocessing and set up the handlers.
        """
        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                        config='TOKENIZER_CONFIG')
        preprocessing_rules = rules.get('query-preprocessing', [])

        self.preprocessors = []

        for func in preprocessing_rules:
            if 'step' not in func:
                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
            if not isinstance(func['step'], str):
                raise UsageError("'step' attribute must be a simple string.")

            module = self.conn.config.load_plugin_module(
                func['step'], 'nominatim_api.query_preprocessing')
            self.preprocessors.append(
                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
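
    # Illustrative sketch (not part of the module): the query-preprocessing
    # section of icu_tokenizer.yaml is expected to be a list of steps, e.g.
    #
    #   query-preprocessing:
    #       - step: normalize
    #
    # where each 'step' names a plugin module under
    # nominatim_api.query_preprocessing.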

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        for func in self.preprocessors:
            phrases = func(phrases)
        query = qmod.QueryStruct(phrases)

        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
                if row.type == 'S':
                    if row.info['op'] in ('in', 'near'):
                        if trange.start == 0:
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                    else:
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query, parts)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query
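
    # Illustrative sketch (not part of the module): a special term
    # (row.type 'S') only becomes a NEAR_ITEM when it starts the query
    # (or covers it completely); anywhere else it is added as a QUALIFIER
    # that merely restricts the remaining terms.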

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text)).strip('-: ')
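
    # Illustrative sketch (not part of the module): the exact result depends
    # on the normalization rules stored with the database, but with typical
    # rules accents are stripped and letters lower-cased, so
    # normalize_text('Zürich ') would come back as 'zurich'.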

    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
        """ Transliterate the phrases and split them into tokens.

            Returns the list of transliterated tokens together with their
            normalized form and a dictionary of words for lookup together
            with their position.
        """
        parts: QueryParts = []
        phrase_start = 0
        words: WordDict = defaultdict(list)
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            phrase_split = re.split('([ :-])', phrase.text)
            # The zip construct will give us the pairs of word/break from
            # the regular expression split. As the split array ends on the
            # final word, we simply use the fillvalue to even out the list and
            # add the phrase break at the end.
            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
                if not word:
                    continue
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            parts.append(QueryPart(term, word,
                                                   PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
                            query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType(breakchar)
                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]

            extract_words(parts, phrase_start, words)

            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words
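
    # Illustrative sketch (not part of the module): re.split('([ :-])', ...)
    # keeps the separators, so 'rue de la gare' splits into
    # ['rue', ' ', 'de', ' ', 'la', ' ', 'gare'] and the zip_longest pairing
    # above yields ('rue', ' '), ('de', ' '), ('la', ' '), ('gare', ',') with
    # the ',' fillvalue standing in for the final phrase break.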

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add tokens to query that are not saved in the database.
        """
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part.token) <= 4 and part.token.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                ICUToken(penalty=0.5, token=0,
                                         count=1, addr_count=1, lookup_word=part.token,
                                         word_token=part.token, info=None))
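
    # Illustrative sketch (not part of the module): a term like '2301' with
    # no house-number entry in the word table still receives a synthetic
    # HOUSENUMBER token here; its high penalty of 0.5 lets database-backed
    # readings win whenever they exist.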

    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add penalties to tokens that depend on the presence of other tokens.
        """
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
                  and len(tlist.tokens[0].lookup_word) <= 3):
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                norm = parts[i].normalized
                for j in range(i + 1, tlist.end):
                    if node.btype != qmod.BreakType.TOKEN:
                        norm += ' ' + parts[j].normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)


def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
    out = query.nodes[0].btype.value
    for node, part in zip(query.nodes[1:], parts):
        out += part.token + node.btype.value
    return out


def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for node in query.nodes:
        for tlist in node.starting:
            for token in tlist.tokens:
                t = cast(ICUToken, token)
                yield [tlist.ttype.name, t.token, t.word_token or '',
                       t.lookup_word or '', t.penalty, t.count, t.info]


async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
    out = ICUQueryAnalyzer(conn)
    await out.setup()

    return out