# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
import re
from itertools import zip_longest

from icu import Transliterator

import sqlalchemy as sa

from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from ..query_preprocessing.config import QueryConfig
from .query_analyzer_factory import AbstractQueryAnalyzer


DB_TO_TOKEN_TYPE = {
    'W': qmod.TokenType.WORD,
    'w': qmod.TokenType.PARTIAL,
    'H': qmod.TokenType.HOUSENUMBER,
    'P': qmod.TokenType.POSTCODE,
    'C': qmod.TokenType.COUNTRY
}


class QueryPart(NamedTuple):
    """ Normalized and transliterated form of a single term in the query.
        When the term came out of a split during the transliteration,
        the normalized string is the full word before transliteration.
        The word number keeps track of the word before transliteration
        and can be used to identify partial transliterated terms.
    """
    token: str
    normalized: str
    word_number: int
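
# Illustrative example (not from the original source): if the transliteration
# of the normalized word 'strassenbahn' were split into 'strassen bahn', it
# would produce the QueryParts ('strassen', 'strassenbahn', 0) and
# ('bahn', 'strassenbahn', 0), both sharing word number 0.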


QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]


def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Return all combinations of words in the terms list after the
        given position.
    """
    total = len(terms)
    for first in range(start, total):
        word = terms[first].token
        yield word, qmod.TokenRange(first, first + 1)
        for last in range(first + 1, min(first + 20, total)):
            word = ' '.join((word, terms[last].token))
            yield word, qmod.TokenRange(first, last + 1)
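
# Illustrative example (not from the original source): for terms with the
# tokens 'main', 'st', '12' and start=0, yield_words() produces 'main',
# 'main st', 'main st 12', 'st', 'st 12' and '12', each together with its
# corresponding token range.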


@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Specialised token for the ICU tokenizer.
    """
    word_token: str
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        assert self.info
        return self.info.get('class', ''), self.info.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Check how well the token matches the given normalized string
            and add a penalty, if necessary.
        """
        if not self.lookup_word:
            return

        seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
        distance = 0
        for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                distance += 1
            elif tag == 'replace':
                distance += max((ato-afrom), (bto-bfrom))
            elif tag != 'equal':
                distance += abs((ato-afrom) - (bto-bfrom))
        self.penalty += (distance/len(self.lookup_word))
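
    # Illustrative example (derived from the logic above): for
    # lookup_word='hauptstr' and norm='hauptstrasse', difflib reports a single
    # insertion at the end of the word, so the distance is 1 and the penalty
    # grows by 1/8 = 0.125.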

    @staticmethod
    def from_db_row(row: SaRow) -> 'ICUToken':
        """ Create an ICUToken from the row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)

        penalty = 0.0
        if row.type == 'w':
            penalty = 0.3
        elif row.type == 'W':
            if len(row.word_token) == 1 and row.word_token == row.word:
                penalty = 0.2 if row.word.isdigit() else 0.3
        elif row.type == 'H':
            penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
            if all(not c.isdigit() for c in row.word_token):
                penalty += 0.2 * (len(row.word_token) - 1)
        elif row.type == 'C':
            if len(row.word_token) == 1:
                penalty = 0.3

        if row.info is None:
            lookup_word = row.word
        else:
            lookup_word = row.info.get('lookup', row.word)
        if lookup_word:
            lookup_word = lookup_word.split('@', 1)[0]
        else:
            lookup_word = row.word_token

        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                        lookup_word=lookup_word,
                        word_token=row.word_token, info=row.info,
                        addr_count=max(1, addr_count))
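
    # Illustrative example (derived from the penalty rules above): a housenumber
    # word_token of '12a' yields a penalty of 0.1 (one non-digit, non-space
    # character), while a purely alphabetic housenumber token 'abc' yields
    # 0.3 + 0.2 * 2 = 0.7.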


class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by an ICU tokenizer.
    """
    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        await self._setup_preprocessing()

        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))
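
    # Illustrative example of a row in the 'word' table as consumed by
    # ICUToken.from_db_row() (values are invented): word_id=1234,
    # word_token='main street', type='W', word='main street',
    # info={'count': 120, 'addr_count': 5}.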

    async def _setup_preprocessing(self) -> None:
        """ Load the rules for preprocessing and set up the handlers.
        """
        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                        config='TOKENIZER_CONFIG')
        preprocessing_rules = rules.get('query-preprocessing', [])

        self.preprocessors = []

        for func in preprocessing_rules:
            if 'step' not in func:
                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
            if not isinstance(func['step'], str):
                raise UsageError("'step' attribute must be a simple string.")

            module = self.conn.config.load_plugin_module(
                func['step'], 'nominatim_api.query_preprocessing')
            self.preprocessors.append(
                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
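
    # The preprocessing rules are read from icu_tokenizer.yaml. A hedged
    # illustration of the expected shape (the step name is an assumption;
    # only the 'step' key is required by the code above):
    #
    #   query-preprocessing:
    #       - step: normalize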

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        for func in self.preprocessors:
            phrases = func(phrases)
        query = qmod.QueryStruct(phrases)

        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                token = ICUToken.from_db_row(row)
                if row.type == 'S':
                    if row.info['op'] in ('in', 'near'):
                        if trange.start == 0:
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                    else:
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query, parts)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text))

    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
        """ Transliterate the phrases and split them into tokens.

            Returns the list of transliterated tokens together with their
            normalized form and a dictionary of words for lookup together
            with their position.
        """
        parts: QueryParts = []
        phrase_start = 0
        words = defaultdict(list)
        wordnr = 0
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            phrase_split = re.split('([ :-])', phrase.text)
            # The zip construct will give us the pairs of word/break from
            # the regular expression split. As the split array ends on the
            # final word, we simply use the fillvalue to even out the list and
            # add the phrase break at the end.
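            # Illustrative example (not part of the original comment):
            # re.split('([ :-])', 'main st 12') returns ['main', ' ', 'st', ' ', '12'],
            # which is paired up below as ('main', ' '), ('st', ' '), ('12', ',').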
            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
                if not word:
                    continue
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            parts.append(QueryPart(term, word, wordnr))
                            query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType(breakchar)
                wordnr += 1

            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)

            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given list of words.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add tokens to query that are not saved in the database.
        """
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part.token) <= 4 and part.token.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                ICUToken(penalty=0.5, token=0,
                                         count=1, addr_count=1, lookup_word=part.token,
                                         word_token=part.token, info=None))
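
    # Illustrative example (derived from the condition above): a digit-only term
    # of up to four characters, such as '23', that got no HOUSENUMBER token from
    # the database receives a synthetic housenumber token with penalty 0.5.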

    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add penalties to tokens that depend on the presence of other tokens.
        """
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
                  and len(tlist.tokens[0].lookup_word) <= 3):
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                norm = parts[i].normalized
                for j in range(i + 1, tlist.end):
                    if parts[j - 1].word_number != parts[j].word_number:
                        norm += ' ' + parts[j].normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)


def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
    out = query.nodes[0].btype.value
    for node, part in zip(query.nodes[1:], parts):
        out += part.token + node.btype.value

    return out


def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for node in query.nodes:
        for tlist in node.starting:
            for token in tlist.tokens:
                t = cast(ICUToken, token)
                yield [tlist.ttype.name, t.token, t.word_token or '',
                       t.lookup_word or '', t.penalty, t.count, t.info]


async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
    out = ICUQueryAnalyzer(conn)
    await out.setup()

    return out
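
# Minimal usage sketch (illustrative only; assumes an open SearchConnection
# 'conn' and that qmod.Phrase/qmod.PhraseType are defined as in the query
# module):
#
#   analyzer = await create_query_analyzer(conn)
#   phrases = [qmod.Phrase(qmod.PhraseType.NONE,
#                          analyzer.normalize_text('hauptstr 23'))]
#   query = await analyzer.analyze_query(phrases)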