# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
import re
from itertools import zip_longest

from icu import Transliterator

import sqlalchemy as sa

from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from ..query_preprocessing.config import QueryConfig
from .query_analyzer_factory import AbstractQueryAnalyzer


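# Mapping of the single-letter 'type' column of the word table to
# the token types used in the tokenized query.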
DB_TO_TOKEN_TYPE = {
    'W': qmod.TOKEN_WORD,
    'w': qmod.TOKEN_PARTIAL,
    'H': qmod.TOKEN_HOUSENUMBER,
    'P': qmod.TOKEN_POSTCODE,
    'C': qmod.TOKEN_COUNTRY
}

PENALTY_IN_TOKEN_BREAK = {
    qmod.BREAK_START: 0.5,
    qmod.BREAK_END: 0.5,
    qmod.BREAK_PHRASE: 0.5,
    qmod.BREAK_SOFT_PHRASE: 0.5,
    qmod.BREAK_WORD: 0.1,
    qmod.BREAK_PART: 0.0,
    qmod.BREAK_TOKEN: 0.0
}


WordDict = Dict[str, List[qmod.TokenRange]]


def extract_words(query: qmod.QueryStruct, start: int, words: WordDict) -> None:
    """ Add all combinations of words in the terms list starting with
        the term leading into node 'start'.

        The words found will be added into the 'words' dictionary with
        their start and end position.
    """
    nodes = query.nodes
    total = len(nodes)
    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
    for first in range(start, total):
        word = nodes[first].term_lookup
        penalty = base_penalty
        words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty))
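        # Also add multi-word sequences starting at this term. The window is
        # capped at 20 consecutive terms and every additional term adds the
        # break penalty of the node it crosses.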
        for last in range(first + 1, min(first + 20, total)):
            word = ' '.join((word, nodes[last].term_lookup))
            penalty += nodes[last - 1].penalty
            words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty))


@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Specialised token for ICU tokenizer.
    """
    word_token: str
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        assert self.info
        return self.info.get('class', ''), self.info.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Check how well the token matches the given normalized string
            and add a penalty, if necessary.
        """
        if not self.lookup_word:
            return

        seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
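        # Derive an edit distance from the SequenceMatcher opcodes and add it,
        # normalized by the length of the lookup word, as a penalty.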
        distance = 0
        for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                distance += 1
            elif tag == 'replace':
                distance += max((ato-afrom), (bto-bfrom))
            elif tag != 'equal':
                distance += abs((ato-afrom) - (bto-bfrom))
        self.penalty += (distance/len(self.lookup_word))

    @staticmethod
    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
        """ Create an ICUToken from a row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)

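        # Add type-dependent penalties: partial words, one-letter full words
        # and country codes as well as housenumbers containing letters are
        # considered less likely matches.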
        penalty = base_penalty
        if row.type == 'w':
            penalty += 0.3
        elif row.type == 'W':
            if len(row.word_token) == 1 and row.word_token == row.word:
                penalty += 0.2 if row.word.isdigit() else 0.3
        elif row.type == 'H':
            penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
            if all(not c.isdigit() for c in row.word_token):
                penalty += 0.2 * (len(row.word_token) - 1)
        elif row.type == 'C':
            if len(row.word_token) == 1:
                penalty += 0.3

        if row.info is None:
            lookup_word = row.word
        else:
            lookup_word = row.info.get('lookup', row.word)
        if lookup_word:
            lookup_word = lookup_word.split('@', 1)[0]
        else:
            lookup_word = row.word_token

        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                        lookup_word=lookup_word,
                        word_token=row.word_token, info=row.info,
                        addr_count=max(1, addr_count))


class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by an ICU tokenizer.
    """
    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        await self._setup_preprocessing()

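        # Register a minimal definition of the word table unless it is
        # already known to the metadata.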
        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))

    async def _setup_preprocessing(self) -> None:
        """ Load the rules for preprocessing and set up the handlers.
        """
        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                        config='TOKENIZER_CONFIG')
        preprocessing_rules = rules.get('query-preprocessing', [])

        self.preprocessors = []

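        # Each rule names a preprocessing step that is loaded as a plugin
        # module from the nominatim_api.query_preprocessing namespace.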
        for func in preprocessing_rules:
            if 'step' not in func:
                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
            if not isinstance(func['step'], str):
                raise UsageError("'step' attribute must be a simple string.")

            module = self.conn.config.load_plugin_module(
                func['step'], 'nominatim_api.query_preprocessing')
            self.preprocessors.append(
                module.create(QueryConfig(func).set_normalizer(self.normalizer)))

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        for func in self.preprocessors:
            phrases = func(phrases)
        query = qmod.QueryStruct(phrases)

        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: query.get_transliterated_query())

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
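                # Rows for special phrases become near-item or qualifier
                # tokens depending on their operator and position in the
                # query; all other rows map directly to a token type.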
                if row.type == 'S':
                    if row.info['op'] in ('in', 'near'):
                        if trange.start == 0:
                            query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
                    else:
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TOKEN_QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)

        self.add_extra_tokens(query)
        self.rerank_tokens(query)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text)).strip('-: ')

    def split_query(self, query: qmod.QueryStruct) -> WordDict:
        """ Transliterate the phrases and split them into tokens.

            Returns a dictionary of words for lookup together
            with their position.
        """
        phrase_start = 1
        words: WordDict = defaultdict(list)
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            phrase_split = re.split('([ :-])', phrase.text)
            # The zip construct will give us the pairs of word/break from
            # the regular expression split. As the split array ends on the
            # final word, we simply use the fillvalue to even out the list and
            # add the phrase break at the end.
            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
                if not word:
                    continue
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
                                           PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
                                           term, word)
                    query.nodes[-1].adjust_break(breakchar,
                                                 PENALTY_IN_TOKEN_BREAK[breakchar])

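            # Collect the lookup words for this phrase. Word combinations
            # never cross a phrase boundary.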
            extract_words(query, phrase_start, words)

            phrase_start = len(query.nodes)
        query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])

        return words

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct) -> None:
        """ Add tokens to the query that are not saved in the database.
        """
        need_hnr = False
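        # Create an artificial housenumber token for any short, digit-only
        # term that directly follows a full word and has no housenumber
        # token of its own in the database.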
        for i, node in enumerate(query.nodes):
            is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART)
            if need_hnr and is_full_token \
                    and len(node.term_normalized) <= 4 and node.term_normalized.isdigit():
                query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER,
                                ICUToken(penalty=0.5, token=0,
                                         count=1, addr_count=1,
                                         lookup_word=node.term_lookup,
                                         word_token=node.term_lookup, info=None))

            need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER)

    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
        """ Add penalties to tokens that depend on the presence of other
            tokens.
        """
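        # Postcode and short housenumber readings penalise competing
        # interpretations of the same range; tokens of most other types are
        # re-matched against the normalized query to penalise poor matches.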
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TOKEN_POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
                       and (repl.ttype != qmod.TOKEN_HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
                  and len(tlist.tokens[0].lookup_word) <= 3):
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
                                if n.btype != qmod.BREAK_TOKEN)
                if not norm:
                    # Can happen when the token only covers a partial term
                    norm = query.nodes[i + 1].term_normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)


def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for node in query.nodes:
        for tlist in node.starting:
            for token in tlist.tokens:
                t = cast(ICUToken, token)
                yield [tlist.ttype, t.token, t.word_token or '',
                       t.lookup_word or '', t.penalty, t.count, t.info]


async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
    out = ICUQueryAnalyzer(conn)
    await out.setup()

    return out