]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/icu_tokenizer.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
11 from collections import defaultdict
12 import dataclasses
13 import difflib
14
15 from icu import Transliterator
16
17 import sqlalchemy as sa
18
19 from ..typing import SaRow
20 from ..sql.sqlalchemy_types import Json
21 from ..connection import SearchConnection
22 from ..logging import log
23 from ..search import query as qmod
24 from ..search.query_analyzer_factory import AbstractQueryAnalyzer
25
26
27 DB_TO_TOKEN_TYPE = {
28     'W': qmod.TokenType.WORD,
29     'w': qmod.TokenType.PARTIAL,
30     'H': qmod.TokenType.HOUSENUMBER,
31     'P': qmod.TokenType.POSTCODE,
32     'C': qmod.TokenType.COUNTRY
33 }
34
35
36 class QueryPart(NamedTuple):
37     """ Normalized and transliterated form of a single term in the query.
38         When the term came out of a split during the transliteration,
39         the normalized string is the full word before transliteration.
40         The word number keeps track of the word before transliteration
41         and can be used to identify partial transliterated terms.
42     """
43     token: str
44     normalized: str
45     word_number: int
46
47
48 QueryParts = List[QueryPart]
49 WordDict = Dict[str, List[qmod.TokenRange]]
50
51
52 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
53     """ Return all combinations of words in the terms list after the
54         given position.
55     """
56     total = len(terms)
57     for first in range(start, total):
58         word = terms[first].token
59         yield word, qmod.TokenRange(first, first + 1)
60         for last in range(first + 1, min(first + 20, total)):
61             word = ' '.join((word, terms[last].token))
62             yield word, qmod.TokenRange(first, last + 1)
63
64
65 @dataclasses.dataclass
66 class ICUToken(qmod.Token):
67     """ Specialised token for ICU tokenizer.
68     """
69     word_token: str
70     info: Optional[Dict[str, Any]]
71
72     def get_category(self) -> Tuple[str, str]:
73         assert self.info
74         return self.info.get('class', ''), self.info.get('type', '')
75
76     def rematch(self, norm: str) -> None:
77         """ Check how well the token matches the given normalized string
78             and add a penalty, if necessary.
79         """
80         if not self.lookup_word:
81             return
82
83         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
84         distance = 0
85         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
86             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
87                 distance += 1
88             elif tag == 'replace':
89                 distance += max((ato-afrom), (bto-bfrom))
90             elif tag != 'equal':
91                 distance += abs((ato-afrom) - (bto-bfrom))
92         self.penalty += (distance/len(self.lookup_word))
93
94     @staticmethod
95     def from_db_row(row: SaRow) -> 'ICUToken':
96         """ Create a ICUToken from the row of the word table.
97         """
98         count = 1 if row.info is None else row.info.get('count', 1)
99         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
100
101         penalty = 0.0
102         if row.type == 'w':
103             penalty = 0.3
104         elif row.type == 'W':
105             if len(row.word_token) == 1 and row.word_token == row.word:
106                 penalty = 0.2 if row.word.isdigit() else 0.3
107         elif row.type == 'H':
108             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
109             if all(not c.isdigit() for c in row.word_token):
110                 penalty += 0.2 * (len(row.word_token) - 1)
111         elif row.type == 'C':
112             if len(row.word_token) == 1:
113                 penalty = 0.3
114
115         if row.info is None:
116             lookup_word = row.word
117         else:
118             lookup_word = row.info.get('lookup', row.word)
119         if lookup_word:
120             lookup_word = lookup_word.split('@', 1)[0]
121         else:
122             lookup_word = row.word_token
123
124         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
125                         lookup_word=lookup_word,
126                         word_token=row.word_token, info=row.info,
127                         addr_count=max(1, addr_count))
128
129
130 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
131     """ Converter for query strings into a tokenized query
132         using the tokens created by a ICU tokenizer.
133     """
134     def __init__(self, conn: SearchConnection) -> None:
135         self.conn = conn
136
137     async def setup(self) -> None:
138         """ Set up static data structures needed for the analysis.
139         """
140         async def _make_normalizer() -> Any:
141             rules = await self.conn.get_property('tokenizer_import_normalisation')
142             return Transliterator.createFromRules("normalization", rules)
143
144         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
145                                                            _make_normalizer)
146
147         async def _make_transliterator() -> Any:
148             rules = await self.conn.get_property('tokenizer_import_transliteration')
149             return Transliterator.createFromRules("transliteration", rules)
150
151         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
152                                                                _make_transliterator)
153
154         if 'word' not in self.conn.t.meta.tables:
155             sa.Table('word', self.conn.t.meta,
156                      sa.Column('word_id', sa.Integer),
157                      sa.Column('word_token', sa.Text, nullable=False),
158                      sa.Column('type', sa.Text, nullable=False),
159                      sa.Column('word', sa.Text),
160                      sa.Column('info', Json))
161
162     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
163         """ Analyze the given list of phrases and return the
164             tokenized query.
165         """
166         log().section('Analyze query (using ICU tokenizer)')
167         normalized = list(filter(lambda p: p.text,
168                                  (qmod.Phrase(p.ptype, self.normalize_text(p.text))
169                                   for p in phrases)))
170         query = qmod.QueryStruct(normalized)
171         log().var_dump('Normalized query', query.source)
172         if not query.source:
173             return query
174
175         parts, words = self.split_query(query)
176         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
177
178         for row in await self.lookup_in_db(list(words.keys())):
179             for trange in words[row.word_token]:
180                 token = ICUToken.from_db_row(row)
181                 if row.type == 'S':
182                     if row.info['op'] in ('in', 'near'):
183                         if trange.start == 0:
184                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
185                     else:
186                         if trange.start == 0 and trange.end == query.num_token_slots():
187                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
188                         else:
189                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
190                 else:
191                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
192
193         self.add_extra_tokens(query, parts)
194         self.rerank_tokens(query, parts)
195
196         log().table_dump('Word tokens', _dump_word_tokens(query))
197
198         return query
199
200     def normalize_text(self, text: str) -> str:
201         """ Bring the given text into a normalized form. That is the
202             standardized form search will work with. All information removed
203             at this stage is inevitably lost.
204         """
205         return cast(str, self.normalizer.transliterate(text))
206
207     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
208         """ Transliterate the phrases and split them into tokens.
209
210             Returns the list of transliterated tokens together with their
211             normalized form and a dictionary of words for lookup together
212             with their position.
213         """
214         parts: QueryParts = []
215         phrase_start = 0
216         words = defaultdict(list)
217         wordnr = 0
218         for phrase in query.source:
219             query.nodes[-1].ptype = phrase.ptype
220             for word in phrase.text.split(' '):
221                 trans = self.transliterator.transliterate(word)
222                 if trans:
223                     for term in trans.split(' '):
224                         if term:
225                             parts.append(QueryPart(term, word, wordnr))
226                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
227                     query.nodes[-1].btype = qmod.BreakType.WORD
228                 wordnr += 1
229             query.nodes[-1].btype = qmod.BreakType.PHRASE
230
231             for word, wrange in yield_words(parts, phrase_start):
232                 words[word].append(wrange)
233
234             phrase_start = len(parts)
235         query.nodes[-1].btype = qmod.BreakType.END
236
237         return parts, words
238
239     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
240         """ Return the token information from the database for the
241             given word tokens.
242         """
243         t = self.conn.t.meta.tables['word']
244         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
245
246     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
247         """ Add tokens to query that are not saved in the database.
248         """
249         for part, node, i in zip(parts, query.nodes, range(1000)):
250             if len(part.token) <= 4 and part[0].isdigit()\
251                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
252                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
253                                 ICUToken(penalty=0.5, token=0,
254                                          count=1, addr_count=1, lookup_word=part.token,
255                                          word_token=part.token, info=None))
256
257     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
258         """ Add penalties to tokens that depend on presence of other token.
259         """
260         for i, node, tlist in query.iter_token_lists():
261             if tlist.ttype == qmod.TokenType.POSTCODE:
262                 for repl in node.starting:
263                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
264                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
265                             or len(tlist.tokens[0].lookup_word) > 4):
266                         repl.add_penalty(0.39)
267             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
268                   and len(tlist.tokens[0].lookup_word) <= 3):
269                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
270                     for repl in node.starting:
271                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
272                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
273             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
274                 norm = parts[i].normalized
275                 for j in range(i + 1, tlist.end):
276                     if parts[j - 1].word_number != parts[j].word_number:
277                         norm += '  ' + parts[j].normalized
278                 for token in tlist.tokens:
279                     cast(ICUToken, token).rematch(norm)
280
281
282 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
283     out = query.nodes[0].btype.value
284     for node, part in zip(query.nodes[1:], parts):
285         out += part.token + node.btype.value
286     return out
287
288
289 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
290     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
291     for node in query.nodes:
292         for tlist in node.starting:
293             for token in tlist.tokens:
294                 t = cast(ICUToken, token)
295                 yield [tlist.ttype.name, t.token, t.word_token or '',
296                        t.lookup_word or '', t.penalty, t.count, t.info]
297
298
299 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
300     """ Create and set up a new query analyzer for a database based
301         on the ICU tokenizer.
302     """
303     out = ICUQueryAnalyzer(conn)
304     await out.setup()
305
306     return out