]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/legacy_tokenizer.py
prefilter bad results before adding details and reranking
[nominatim.git] / nominatim / api / search / legacy_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the legacy tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from copy import copy
12 from collections import defaultdict
13 import dataclasses
14
15 import sqlalchemy as sa
16
17 from nominatim.typing import SaRow
18 from nominatim.api.connection import SearchConnection
19 from nominatim.api.logging import log
20 from nominatim.api.search import query as qmod
21 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
22
def yield_words(terms: List[str], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Enumerate all runs of consecutive terms that begin at or after
        index `start` of `terms`.

        For every start index, the single term is produced first and then
        progressively longer runs (at most 20 terms each). Each run is
        returned as the space-joined string together with the token range
        it covers.
    """
    num_terms = len(terms)
    for first in range(start, num_terms):
        cap = min(first + 20, num_terms)
        word = terms[first]
        end = first + 1
        yield word, qmod.TokenRange(first, end)
        # Grow the run one term at a time until the cap is reached.
        while end < cap:
            word = ' '.join((word, terms[end]))
            end += 1
            yield word, qmod.TokenRange(first, end)
34
35
@dataclasses.dataclass
class LegacyToken(qmod.Token):
    """ Specialised token for legacy tokenizer.

        Adds the raw information from the `word` table row (or from the
        synthetic house number tokens created by the analyzer) to the
        generic query token.
    """
    # Word token of the row; the analyzer stores it without surrounding spaces.
    word_token: str
    # (class, type) of the word table row, or None when the row has no class.
    category: Optional[Tuple[str, str]]
    # Country code of the word table row, or None.
    country: Optional[str]
    # Operator column of the word table row (e.g. 'in' or 'near'), or None.
    operator: Optional[str]

    @property
    def info(self) -> Dict[str, Any]:
        """ Dictionary of additional properties of the token.
            Should only be used for debugging purposes.
        """
        return {'category': self.category,
                'country': self.country,
                'operator': self.operator}


    def get_category(self) -> Tuple[str, str]:
        """ Return the (class, type) category of the token.

            Must only be called on tokens that carry a category;
            otherwise the assertion fails.
        """
        assert self.category
        return self.category
58
59
class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by a legacy tokenizer.
    """

    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.

            Reads the 'tokenizer_maxwordfreq' property (used by
            `make_token()` to decide whether a partial term is indexed)
            and registers a definition of the `word` table with the
            connection's table metadata, unless one already exists.
        """
        self.max_word_freq = int(await self.conn.get_property('tokenizer_maxwordfreq'))
        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('class', sa.Text),
                     sa.Column('type', sa.Text),
                     sa.Column('country_code', sa.Text),
                     sa.Column('search_name_count', sa.Integer),
                     sa.Column('operator', sa.Text))


    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.

            The phrases are normalized in the database with
            `make_standard_name()`; phrases that normalize to an empty
            result are dropped. The normalized text is split into terms,
            all term sequences are looked up in the `word` table and the
            matching tokens are added to the query.
        """
        log().section('Analyze query (using Legacy tokenizer)')

        normalized = []
        if phrases:
            # One SELECT (without FROM) normalizes all phrases at once;
            # only its first row is needed.
            for row in await self.conn.execute(sa.select(*(sa.func.make_standard_name(p.text)
                                                           for p in phrases))):
                normalized = [qmod.Phrase(p.ptype, r) for r, p in zip(row, phrases) if r]
                break

        query = qmod.QueryStruct(normalized)
        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        lookup_words = list(words.keys())
        log().var_dump('Split query', parts)
        log().var_dump('Extracted words', lookup_words)

        for row in await self.lookup_in_db(lookup_words):
            for trange in words[row.word_token.strip()]:
                # Build a separate token object for every range: token
                # penalties are mutable (see the copy() below), so tokens
                # must not be shared between ranges.
                token, ttype = self.make_token(row)
                if ttype == qmod.TokenType.NEAR_ITEM:
                    # Near items are only accepted at the start of the query.
                    if trange.start == 0:
                        query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                elif ttype == qmod.TokenType.QUALIFIER:
                    query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                    # At the start or end of the query a qualifier is also
                    # offered as a near item, penalised by the number of
                    # token slots.
                    if trange.start == 0 or trange.end == query.num_token_slots():
                        token = copy(token)
                        token.penalty += 0.1 * (query.num_token_slots())
                        query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                elif ttype != qmod.TokenType.PARTIAL or trange.start + 1 == trange.end:
                    # Partial tokens are only used for single terms.
                    query.add_token(trange, ttype, token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query


    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form.

            This only removes case, so some difference with the normalization
            in the phrase remains.
        """
        return text.lower()


    def split_query(self, query: qmod.QueryStruct) -> Tuple[List[str],
                                                            Dict[str, List[qmod.TokenRange]]]:
        """ Split the (already normalized) phrases into terms.

            Adds one break node to `query` per term and returns the list
            of terms together with a dictionary mapping every sequence of
            consecutive terms within a phrase to the token ranges where it
            occurs.
        """
        parts: List[str] = []
        phrase_start = 0
        words = defaultdict(list)
        for phrase in query.source:
            # Tokens starting at the break before this phrase belong to it.
            query.nodes[-1].ptype = phrase.ptype
            # Every space-separated term is a word of its own; empty strings
            # produced by repeated spaces are skipped.
            for term in phrase.text.split(' '):
                if term:
                    parts.append(term)
                    query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType.WORD
            query.nodes[-1].btype = qmod.BreakType.PHRASE
            # Word sequences never cross phrase boundaries.
            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)
            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words


    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.

            Each word is looked up both as-is (partial term) and with a
            leading space (full word entry).
        """
        t = self.conn.t.meta.tables['word']

        sql = t.select().where(t.c.word_token.in_(words + [' ' + w for w in words]))

        return await self.conn.execute(sql)


    def make_token(self, row: SaRow) -> Tuple[LegacyToken, qmod.TokenType]:
        """ Create a LegacyToken from the row of the word table.
            Also determines the type of token.
        """
        penalty = 0.0
        is_indexed = True

        # 'class' is a Python keyword, so it cannot be accessed as row.class.
        rowclass = getattr(row, 'class')

        if row.country_code is not None:
            ttype = qmod.TokenType.COUNTRY
            lookup_word = row.country_code
        elif rowclass is not None:
            if rowclass == 'place' and row.type == 'house':
                ttype = qmod.TokenType.HOUSENUMBER
                lookup_word = row.word_token[1:]
            elif rowclass == 'place' and row.type == 'postcode':
                ttype = qmod.TokenType.POSTCODE
                lookup_word = row.word_token[1:]
            else:
                ttype = qmod.TokenType.NEAR_ITEM if row.operator in ('in', 'near')\
                        else qmod.TokenType.QUALIFIER
                lookup_word = row.word
        elif row.word_token.startswith(' '):
            ttype = qmod.TokenType.WORD
            lookup_word = row.word or row.word_token[1:]
        else:
            ttype = qmod.TokenType.PARTIAL
            lookup_word = row.word_token
            penalty = 0.21
            # Partial terms that are too frequent are marked as not indexed.
            # A NULL count is treated as 0 (like the `count` fallback below)
            # instead of failing the comparison with a TypeError.
            if (row.search_name_count or 0) > self.max_word_freq:
                is_indexed = False

        return LegacyToken(penalty=penalty, token=row.word_id,
                           count=row.search_name_count or 1,
                           lookup_word=lookup_word,
                           word_token=row.word_token.strip(),
                           category=(rowclass, row.type) if rowclass is not None else None,
                           country=row.country_code,
                           operator=row.operator,
                           is_indexed=is_indexed),\
               ttype


    def add_extra_tokens(self, query: qmod.QueryStruct, parts: List[str]) -> None:
        """ Add tokens to query that are not saved in the database.

            Every term consisting of at most 4 digits is offered as a
            house number (penalty 0.5), unless a house number token
            already covers exactly that term.
        """
        for i, (part, node) in enumerate(zip(parts, query.nodes)):
            if len(part) <= 4 and part.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                LegacyToken(penalty=0.5, token=0, count=1,
                                            lookup_word=part, word_token=part,
                                            category=None, country=None,
                                            operator=None, is_indexed=True))


    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
        """ Add penalties to tokens that depend on presence of other token.

            * Next to a postcode, other token lists over the same range get
              a penalty of 0.39. House numbers are only penalised when the
              postcode's lookup word is longer than 4 characters.
            * Next to a house number of at most 3 characters that contains
              a digit, other token lists over the same range get a penalty
              of 0.5 minus the house number token's own penalty.
        """
        for _, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
                 and len(tlist.tokens[0].lookup_word) <= 3:
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
251
252
253
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    """ Generate the rows of the debug table listing the query's tokens.

        The first row is the column header. After that comes one row per
        token, walking the nodes in order, then each node's token lists,
        then the tokens within each list.
    """
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    token_lists = (tlist for node in query.nodes for tlist in node.starting)
    for tlist in token_lists:
        for raw_token in tlist.tokens:
            tok = cast(LegacyToken, raw_token)
            yield [tlist.ttype.name, tok.token,
                   tok.word_token or '', tok.lookup_word or '',
                   tok.penalty, tok.count, tok.info]
262
263
async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the legacy tokenizer.

        The analyzer's `setup()` is awaited before it is returned.
    """
    out = LegacyQueryAnalyzer(conn)
    await out.setup()

    return out