# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
import re
from itertools import zip_longest

from icu import Transliterator

import sqlalchemy as sa

from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from ..query_preprocessing.config import QueryConfig
from .query_analyzer_factory import AbstractQueryAnalyzer


DB_TO_TOKEN_TYPE = {
    'W': qmod.TokenType.WORD,
    'w': qmod.TokenType.PARTIAL,
    'H': qmod.TokenType.HOUSENUMBER,
    'P': qmod.TokenType.POSTCODE,
    'C': qmod.TokenType.COUNTRY
}
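
# Penalty attached to a QueryPart for the break that follows it. yield_words()
# adds these up when it joins consecutive parts into candidate multi-part
# words, so readings that span phrase boundaries rank below readings that
# only span ordinary word breaks.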
PENALTY_IN_TOKEN_BREAK = {
    qmod.BreakType.START: 0.5,
    qmod.BreakType.END: 0.5,
    qmod.BreakType.PHRASE: 0.5,
    qmod.BreakType.SOFT_PHRASE: 0.5,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.0,
    qmod.BreakType.TOKEN: 0.0
}


@dataclasses.dataclass
class QueryPart:
    """ Normalized and transliterated form of a single term in the query.
        When the term came out of a split during the transliteration,
        the normalized string is the full word before transliteration.
        The word number keeps track of the word before transliteration
        and can be used to identify partial transliterated terms.
        Penalty is the break penalty for the break following the token.
    """
    token: str
    normalized: str
    word_number: int
    penalty: float


QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]
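

# Illustration of what yield_words() below generates (hypothetical parts,
# not taken from real data): for the parts ['haupt', 'str', '55'] and
# start 0 it yields 'haupt', 'haupt str', 'haupt str 55', 'str', 'str 55'
# and '55', each with a TokenRange covering the spanned parts and the
# accumulated break penalties (multi-part words are capped at 20 parts).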
def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Return all combinations of words in the terms list after the
        given position.
    """
    total = len(terms)
    for first in range(start, total):
        word = terms[first].token
        penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
        yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
        for last in range(first + 1, min(first + 20, total)):
            word = ' '.join((word, terms[last].token))
            penalty += terms[last - 1].penalty
            yield word, qmod.TokenRange(first, last + 1, penalty=penalty)


@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Specialised token for ICU tokenizer.
    """
    word_token: str
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        assert self.info
        return self.info.get('class', ''), self.info.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Check how well the token matches the given normalized string
            and add a penalty, if necessary.
        """
        if not self.lookup_word:
            return

        # Estimate the edit distance between lookup word and query word.
        seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
        distance = 0
        for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                distance += 1
            elif tag == 'replace':
                distance += max((ato-afrom), (bto-bfrom))
            elif tag != 'equal':
                distance += abs((ato-afrom) - (bto-bfrom))
        self.penalty += (distance/len(self.lookup_word))

    @staticmethod
    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
        """ Create an ICUToken from a row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
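
        # The penalties below are heuristics: partial words ('w') are always
        # down-weighted, one-character full words ('W') and one-character
        # country tokens ('C') are unlikely matches, and house-number
        # tokens ('H') pay for every letter they contain.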
        penalty = base_penalty
        if row.type == 'w':
            penalty += 0.3
        elif row.type == 'W':
            if len(row.word_token) == 1 and row.word_token == row.word:
                penalty += 0.2 if row.word.isdigit() else 0.3
        elif row.type == 'H':
            penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
            if all(not c.isdigit() for c in row.word_token):
                penalty += 0.2 * (len(row.word_token) - 1)
        elif row.type == 'C':
            if len(row.word_token) == 1:
                penalty += 0.3

        if row.info is None:
            lookup_word = row.word
        else:
            lookup_word = row.info.get('lookup', row.word)
        if lookup_word:
            lookup_word = lookup_word.split('@', 1)[0]
        else:
            lookup_word = row.word_token

        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                        lookup_word=lookup_word,
                        word_token=row.word_token, info=row.info,
                        addr_count=max(1, addr_count))


class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by an ICU tokenizer.
    """
    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        await self._setup_preprocessing()
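
        # The word table itself is created by the tokenizer during import;
        # here we merely register a SQLAlchemy description of it so that it
        # can be queried, in case no other component has done so yet.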
        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))

    async def _setup_preprocessing(self) -> None:
        """ Load the rules for preprocessing and set up the handlers.
        """
        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                        config='TOKENIZER_CONFIG')
        preprocessing_rules = rules.get('query-preprocessing', [])

        self.preprocessors = []

        for func in preprocessing_rules:
            if 'step' not in func:
                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
            if not isinstance(func['step'], str):
                raise UsageError("'step' attribute must be a simple string.")

            module = self.conn.config.load_plugin_module(
                func['step'], 'nominatim_api.query_preprocessing')
            self.preprocessors.append(
                module.create(QueryConfig(func).set_normalizer(self.normalizer)))

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        for func in self.preprocessors:
            phrases = func(phrases)

        if len(phrases) == 1 \
                and phrases[0].text.count(' ') > 3 \
                and max(len(s) for s in phrases[0].text.split()) < 3:
            # A single phrase consisting of many very short words is almost
            # certainly not a meaningful search term and would produce an
            # excessive number of word combinations. Refuse to analyze it.
            return qmod.QueryStruct([])

        query = qmod.QueryStruct(phrases)

        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
                if row.type == 'S':
                    # Special terms become NEAR_ITEM or QUALIFIER tokens
                    # depending on their operator and position in the query.
                    if row.info['op'] in ('in', 'near'):
                        if trange.start == 0:
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                    else:
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query, parts)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text))

    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
        """ Transliterate the phrases and split them into tokens.

            Returns the list of transliterated tokens together with their
            normalized form and a dictionary of words for lookup together
            with their position.
        """
        parts: QueryParts = []
        phrase_start = 0
        words = defaultdict(list)
        wordnr = 0
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            phrase_split = re.split('([ :-])', phrase.text)
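            # e.g. re.split('([ :-])', 'main st:12') -> ['main', ' ', 'st', ':', '12']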
            # The zip construct will give us the pairs of word/break from
            # the regular expression split. As the split array ends on the
            # final word, we simply use the fillvalue to even out the list and
            # add the phrase break at the end.
            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
                if not word:
                    continue
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            parts.append(QueryPart(term, word, wordnr,
                                                   PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
                            query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType(breakchar)
                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
                wordnr += 1

            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)

            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add tokens to query that are not saved in the database.
        """
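        # A short, purely numeric term may always double as a house number,
        # even when it is not in the word table. The high penalty keeps this
        # as a fallback reading only.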
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part.token) <= 4 and part.token.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                ICUToken(penalty=0.5, token=0,
                                         count=1, addr_count=1, lookup_word=part.token,
                                         word_token=part.token, info=None))

    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add penalties to tokens that depend on the presence of other tokens.
        """
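        # Three situations are reranked here: readings that compete with a
        # postcode, readings that compete with a short house number, and
        # multi-word name tokens whose lookup word needs to be re-checked
        # against the actual normalized query (rematch).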
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
                  and len(tlist.tokens[0].lookup_word) <= 3):
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                norm = parts[i].normalized
                for j in range(i + 1, tlist.end):
                    if parts[j - 1].word_number != parts[j].word_number:
                        norm += ' ' + parts[j].normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)


def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
    out = query.nodes[0].btype.value
    for node, part in zip(query.nodes[1:], parts):
        out += part.token + node.btype.value

    return out


def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for node in query.nodes:
        for tlist in node.starting:
            for token in tlist.tokens:
                t = cast(ICUToken, token)
                yield [tlist.ttype.name, t.token, t.word_token or '',
                       t.lookup_word or '', t.penalty, t.count, t.info]


async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
    out = ICUQueryAnalyzer(conn)
    await out.setup()

    return out
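

# Typical usage (illustrative sketch only; assumes an open SearchConnection
# `conn` and that the caller runs inside an async context):
#
#     analyzer = await create_query_analyzer(conn)
#     query = await analyzer.analyze_query(
#         [qmod.Phrase(qmod.PhraseType.NONE, 'Hauptstr. 55, Berlin')])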