]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/icu_tokenizer.py
query analyzer for ICU tokenizer
[nominatim.git] / nominatim / api / search / icu_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the ICU tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
11 from copy import copy
12 from collections import defaultdict
13 import dataclasses
14 import difflib
15
16 from icu import Transliterator
17
18 import sqlalchemy as sa
19
20 from nominatim.typing import SaRow
21 from nominatim.api.connection import SearchConnection
22 from nominatim.api.logging import log
23 from nominatim.api.search import query as qmod
24
25 # XXX: TODO
26 class AbstractQueryAnalyzer:
27     pass
28
29
30 DB_TO_TOKEN_TYPE = {
31     'W': qmod.TokenType.WORD,
32     'w': qmod.TokenType.PARTIAL,
33     'H': qmod.TokenType.HOUSENUMBER,
34     'P': qmod.TokenType.POSTCODE,
35     'C': qmod.TokenType.COUNTRY
36 }
37
38
39 class QueryPart(NamedTuple):
40     """ Normalized and transliterated form of a single term in the query.
41         When the term came out of a split during the transliteration,
42         the normalized string is the full word before transliteration.
43         The word number keeps track of the word before transliteration
44         and can be used to identify partial transliterated terms.
45     """
46     token: str
47     normalized: str
48     word_number: int
49
50
51 QueryParts = List[QueryPart]
52 WordDict = Dict[str, List[qmod.TokenRange]]
53
54 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
55     """ Return all combinations of words in the terms list after the
56         given position.
57     """
58     total = len(terms)
59     for first in range(start, total):
60         word = terms[first].token
61         yield word, qmod.TokenRange(first, first + 1)
62         for last in range(first + 1, min(first + 20, total)):
63             word = ' '.join((word, terms[last].token))
64             yield word, qmod.TokenRange(first, last + 1)
65
66
67 @dataclasses.dataclass
68 class ICUToken(qmod.Token):
69     """ Specialised token for ICU tokenizer.
70     """
71     word_token: str
72     info: Optional[Dict[str, Any]]
73
74     def get_category(self) -> Tuple[str, str]:
75         assert self.info
76         return self.info.get('class', ''), self.info.get('type', '')
77
78
79     def rematch(self, norm: str) -> None:
80         """ Check how well the token matches the given normalized string
81             and add a penalty, if necessary.
82         """
83         if not self.lookup_word:
84             return
85
86         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
87         distance = 0
88         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
89             if tag == 'delete' and (afrom == 0 or ato == len(self.lookup_word)):
90                 distance += 1
91             elif tag == 'replace':
92                 distance += max((ato-afrom), (bto-bfrom))
93             elif tag != 'equal':
94                 distance += abs((ato-afrom) - (bto-bfrom))
95         self.penalty += (distance/len(self.lookup_word))
96
97
98     @staticmethod
99     def from_db_row(row: SaRow) -> 'ICUToken':
100         """ Create a ICUToken from the row of the word table.
101         """
102         count = 1 if row.info is None else row.info.get('count', 1)
103
104         penalty = 0.0
105         if row.type == 'w':
106             penalty = 0.3
107         elif row.type == 'H':
108             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
109             if all(not c.isdigit() for c in row.word_token):
110                 penalty += 0.2 * (len(row.word_token) - 1)
111
112         if row.info is None:
113             lookup_word = row.word
114         else:
115             lookup_word = row.info.get('lookup', row.word)
116         if lookup_word:
117             lookup_word = lookup_word.split('@', 1)[0]
118         else:
119             lookup_word = row.word_token
120
121         return ICUToken(penalty=penalty, token=row.word_id, count=count,
122                         lookup_word=lookup_word, is_indexed=True,
123                         word_token=row.word_token, info=row.info)
124
125
126
127 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
128     """ Converter for query strings into a tokenized query
129         using the tokens created by a ICU tokenizer.
130     """
131
132     def __init__(self, conn: SearchConnection) -> None:
133         self.conn = conn
134
135
136     async def setup(self) -> None:
137         """ Set up static data structures needed for the analysis.
138         """
139         rules = await self.conn.get_property('tokenizer_import_normalisation')
140         self.normalizer = Transliterator.createFromRules("normalization", rules)
141         rules = await self.conn.get_property('tokenizer_import_transliteration')
142         self.transliterator = Transliterator.createFromRules("transliteration", rules)
143
144         if 'word' not in self.conn.t.meta.tables:
145             sa.Table('word', self.conn.t.meta,
146                      sa.Column('word_id', sa.Integer),
147                      sa.Column('word_token', sa.Text, nullable=False),
148                      sa.Column('type', sa.Text, nullable=False),
149                      sa.Column('word', sa.Text),
150                      sa.Column('info', self.conn.t.types.Json))
151
152
153     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
154         """ Analyze the given list of phrases and return the
155             tokenized query.
156         """
157         log().section('Analyze query (using ICU tokenizer)')
158         normalized = list(filter(lambda p: p.text,
159                                  (qmod.Phrase(p.ptype, self.normalizer.transliterate(p.text))
160                                   for p in phrases)))
161         query = qmod.QueryStruct(normalized)
162         log().var_dump('Normalized query', query.source)
163         if not query.source:
164             return query
165
166         parts, words = self.split_query(query)
167         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
168
169         for row in await self.lookup_in_db(list(words.keys())):
170             for trange in words[row.word_token]:
171                 token = ICUToken.from_db_row(row)
172                 if row.type == 'S':
173                     if row.info['op'] in ('in', 'near'):
174                         if trange.start == 0:
175                             query.add_token(trange, qmod.TokenType.CATEGORY, token)
176                     else:
177                         query.add_token(trange, qmod.TokenType.QUALIFIER, token)
178                         if trange.start == 0 or trange.end == query.num_token_slots():
179                             token = copy(token)
180                             token.penalty += 0.1 * (query.num_token_slots())
181                             query.add_token(trange, qmod.TokenType.CATEGORY, token)
182                 else:
183                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
184
185         self.add_extra_tokens(query, parts)
186         self.rerank_tokens(query, parts)
187
188         log().table_dump('Word tokens', _dump_word_tokens(query))
189
190         return query
191
192
193     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
194         """ Transliterate the phrases and split them into tokens.
195
196             Returns the list of transliterated tokens together with their
197             normalized form and a dictionary of words for lookup together
198             with their position.
199         """
200         parts: QueryParts = []
201         phrase_start = 0
202         words = defaultdict(list)
203         wordnr = 0
204         for phrase in query.source:
205             query.nodes[-1].ptype = phrase.ptype
206             for word in phrase.text.split(' '):
207                 trans = self.transliterator.transliterate(word)
208                 if trans:
209                     for term in trans.split(' '):
210                         if term:
211                             parts.append(QueryPart(term, word, wordnr))
212                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
213                     query.nodes[-1].btype = qmod.BreakType.WORD
214                 wordnr += 1
215             query.nodes[-1].btype = qmod.BreakType.PHRASE
216
217             for word, wrange in yield_words(parts, phrase_start):
218                 words[word].append(wrange)
219
220             phrase_start = len(parts)
221         query.nodes[-1].btype = qmod.BreakType.END
222
223         return parts, words
224
225
226     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
227         """ Return the token information from the database for the
228             given word tokens.
229         """
230         t = self.conn.t.meta.tables['word']
231         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
232
233
234     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
235         """ Add tokens to query that are not saved in the database.
236         """
237         for part, node, i in zip(parts, query.nodes, range(1000)):
238             if len(part.token) <= 4 and part[0].isdigit()\
239                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
240                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
241                                 ICUToken(0.5, 0, 1, part.token, True, part.token, None))
242
243
244     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
245         """ Add penalties to tokens that depend on presence of other token.
246         """
247         for i, node, tlist in query.iter_token_lists():
248             if tlist.ttype == qmod.TokenType.POSTCODE:
249                 for repl in node.starting:
250                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
251                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
252                             or len(tlist.tokens[0].lookup_word) > 4):
253                         repl.add_penalty(0.39)
254             elif tlist.ttype == qmod.TokenType.HOUSENUMBER:
255                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
256                     for repl in node.starting:
257                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER \
258                            and (repl.ttype != qmod.TokenType.HOUSENUMBER
259                                 or len(tlist.tokens[0].lookup_word) <= 3):
260                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
261             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
262                 norm = parts[i].normalized
263                 for j in range(i + 1, tlist.end):
264                     if parts[j - 1].word_number != parts[j].word_number:
265                         norm += '  ' + parts[j].normalized
266                 for token in tlist.tokens:
267                     cast(ICUToken, token).rematch(norm)
268
269
270 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
271     out = query.nodes[0].btype.value
272     for node, part in zip(query.nodes[1:], parts):
273         out += part.token + node.btype.value
274     return out
275
276
277 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
278     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
279     for node in query.nodes:
280         for tlist in node.starting:
281             for token in tlist.tokens:
282                 t = cast(ICUToken, token)
283                 yield [tlist.ttype.name, t.token, t.word_token or '',
284                        t.lookup_word or '', t.penalty, t.count, t.info]
285
286
287 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
288     """ Create and set up a new query analyzer for a database based
289         on the ICU tokenizer.
290     """
291     out = ICUQueryAnalyzer(conn)
292     await out.setup()
293
294     return out