]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/legacy_tokenizer.py
block search queries with too many tokens
[nominatim.git] / nominatim / api / search / legacy_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the legacy tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from copy import copy
12 from collections import defaultdict
13 import dataclasses
14
15 import sqlalchemy as sa
16
17 from nominatim.typing import SaRow
18 from nominatim.api.connection import SearchConnection
19 from nominatim.api.logging import log
20 from nominatim.api.search import query as qmod
21 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
22
def yield_words(terms: List[str], start: int,
                max_words: int = 20) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Return all combinations of consecutive words in the terms list
        that begin at or after the given position.

        Each combination is yielded as a tuple of the words joined by a
        single blank and the token range it covers. A combination spans
        at most `max_words` terms. Without such a limit, the number of
        combinations grows quadratically with the number of terms.
        Single-term combinations are always produced.
    """
    total = len(terms)
    for first in range(start, total):
        word = terms[first]
        yield word, qmod.TokenRange(first, first + 1)
        # Grow the combination term by term until the span limit or the
        # end of the term list is reached.
        for last in range(first + 1, min(first + max_words, total)):
            word = ' '.join((word, terms[last]))
            yield word, qmod.TokenRange(first, last + 1)
34
35
@dataclasses.dataclass
class LegacyToken(qmod.Token):
    """ Token implementation used by the query analyzer of the legacy
        tokenizer.

        Besides the generic token fields, it keeps the token string it
        was created from and the category, country and operator
        information of the corresponding word table row.
    """
    # Token string the token was created from.
    word_token: str
    # (class, type) pair for category/qualifier tokens, otherwise None.
    category: Optional[Tuple[str, str]]
    # Country code for country tokens, otherwise None.
    country: Optional[str]
    # Operator column of the word table row (may be None).
    operator: Optional[str]

    @property
    def info(self) -> Dict[str, Any]:
        """ Additional properties of the token as a dictionary.
            Intended for debugging output only.
        """
        return dict(category=self.category,
                    country=self.country,
                    operator=self.operator)


    def get_category(self) -> Tuple[str, str]:
        """ Return the (class, type) category of the token.
            Must only be called on tokens that have a category.
        """
        category = self.category
        assert category
        return category
58
59
class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by a legacy tokenizer.
    """

    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.

            Reads the maximum word frequency from the database properties.
            Registers a definition of the `word` table with the
            connection's metadata if it is not known there yet.
        """
        self.max_word_freq = int(await self.conn.get_property('tokenizer_maxwordfreq'))
        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('class', sa.Text),
                     sa.Column('type', sa.Text),
                     sa.Column('country_code', sa.Text),
                     sa.Column('search_name_count', sa.Integer),
                     sa.Column('operator', sa.Text))


    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.

            The phrases are normalized in the database with
            `make_standard_name()`. Phrases that normalize to an empty
            result are dropped. All word combinations of the normalized
            terms are then looked up in the `word` table, and the
            matching tokens are added to the query.
        """
        log().section('Analyze query (using Legacy tokenizer)')

        normalized = []
        if phrases:
            # A SELECT without FROM returns a single row with one
            # normalized value per phrase.
            for row in await self.conn.execute(sa.select(*(sa.func.make_standard_name(p.text)
                                                           for p in phrases))):
                normalized = [qmod.Phrase(p.ptype, r) for r, p in zip(row, phrases) if r]
                break

        query = qmod.QueryStruct(normalized)
        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        lookup_words = list(words.keys())
        log().var_dump('Split query', parts)
        log().var_dump('Extracted words', lookup_words)

        for row in await self.lookup_in_db(lookup_words):
            for trange in words[row.word_token.strip()]:
                # Create a separate token for every range, so that penalty
                # changes made for one range do not affect the others.
                token, ttype = self.make_token(row)
                if ttype == qmod.TokenType.CATEGORY:
                    # Category tokens are only accepted at the beginning
                    # of the query.
                    if trange.start == 0:
                        query.add_token(trange, qmod.TokenType.CATEGORY, token)
                elif ttype == qmod.TokenType.QUALIFIER:
                    query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                    # At the start or end of the query, a qualifier may also
                    # act as a category, with a penalty growing with the
                    # query length.
                    if trange.start == 0 or trange.end == query.num_token_slots():
                        token = copy(token)
                        token.penalty += 0.1 * (query.num_token_slots())
                        query.add_token(trange, qmod.TokenType.CATEGORY, token)
                elif ttype != qmod.TokenType.PARTIAL or trange.start + 1 == trange.end:
                    # Partial tokens are only used for single terms.
                    query.add_token(trange, ttype, token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query


    def split_query(self, query: qmod.QueryStruct) -> Tuple[List[str],
                                                            Dict[str, List[qmod.TokenRange]]]:
        """ Split the normalized phrases into terms and add a query node
            for each term.

            Returns a list of all terms and a dictionary that maps each
            word combination to look up to the term ranges where it
            occurs. Word combinations never cross phrase boundaries.
        """
        parts: List[str] = []
        phrase_start = 0
        words: Dict[str, List[qmod.TokenRange]] = defaultdict(list)
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            # Terms are separated by blanks; empty strings resulting from
            # repeated blanks are skipped.
            for term in phrase.text.split(' '):
                if term:
                    parts.append(term)
                    query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType.WORD
            query.nodes[-1].btype = qmod.BreakType.PHRASE
            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)
            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words


    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']

        # Full-word tokens are stored with a leading blank, partial tokens
        # without, so look up both variants of every word.
        sql = t.select().where(t.c.word_token.in_(words + [' ' + w for w in words]))

        return await self.conn.execute(sql)


    def make_token(self, row: SaRow) -> Tuple[LegacyToken, qmod.TokenType]:
        """ Create a LegacyToken from the row of the word table.
            Also determines the type of token.
        """
        penalty = 0.0
        is_indexed = True

        # 'class' is a Python keyword, so attribute access is not possible.
        rowclass = getattr(row, 'class')

        if row.country_code is not None:
            ttype = qmod.TokenType.COUNTRY
            lookup_word = row.country_code
        elif rowclass is not None:
            if rowclass == 'place' and row.type == 'house':
                ttype = qmod.TokenType.HOUSENUMBER
                lookup_word = row.word_token[1:]
            elif rowclass == 'place' and row.type == 'postcode':
                ttype = qmod.TokenType.POSTCODE
                lookup_word = row.word_token[1:]
            else:
                ttype = qmod.TokenType.CATEGORY if row.operator in ('in', 'near')\
                        else qmod.TokenType.QUALIFIER
                lookup_word = row.word
        elif row.word_token.startswith(' '):
            ttype = qmod.TokenType.WORD
            lookup_word = row.word or row.word_token[1:]
        else:
            ttype = qmod.TokenType.PARTIAL
            lookup_word = row.word_token
            penalty = 0.21
            # search_name_count is a nullable column. Treat a missing count
            # as "not frequent" instead of failing on a comparison with None.
            if (row.search_name_count or 0) > self.max_word_freq:
                is_indexed = False

        return LegacyToken(penalty=penalty, token=row.word_id,
                           count=row.search_name_count or 1,
                           lookup_word=lookup_word,
                           word_token=row.word_token.strip(),
                           category=(rowclass, row.type) if rowclass is not None else None,
                           country=row.country_code,
                           operator=row.operator,
                           is_indexed=is_indexed),\
               ttype


    def add_extra_tokens(self, query: qmod.QueryStruct, parts: List[str]) -> None:
        """ Add tokens to query that are not saved in the database.

            Every term consisting of at most four digits is added as a
            house number token with a penalty of 0.5, unless the query
            already has a house number token for that term.
        """
        for i, (part, node) in enumerate(zip(parts, query.nodes)):
            if len(part) <= 4 and part.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                LegacyToken(penalty=0.5, token=0, count=1,
                                            lookup_word=part, word_token=part,
                                            category=None, country=None,
                                            operator=None, is_indexed=True))


    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
        """ Add penalties to tokens that depend on presence of other token.

            * For a postcode token list, the other token lists of the same
              node ending at the same position get a penalty of 0.39.
              House number lists are exempt when the postcode has at
              most four characters.
            * For a house number of at most three characters containing a
              digit, the other non-house-number token lists of the same
              node ending at the same position are penalized by 0.5 minus
              the house number token's own penalty.
        """
        for _, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
                 and len(tlist.tokens[0].lookup_word) <= 3:
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
242
243
244
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    """ Generate table rows describing all tokens of the query for the
        debug log.

        The first row is the header; every following row describes one token.
    """
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    token_lists = (tl for qnode in query.nodes for tl in qnode.starting)
    for tl in token_lists:
        for raw_token in tl.tokens:
            legacy = cast(LegacyToken, raw_token)
            yield [tl.ttype.name, legacy.token,
                   legacy.word_token or '', legacy.lookup_word or '',
                   legacy.penalty, legacy.count, legacy.info]
253
254
async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the legacy tokenizer.
    """
    out = LegacyQueryAnalyzer(conn)
    await out.setup()

    return out