]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/legacy_tokenizer.py
fix some typos
[nominatim.git] / src / nominatim_api / search / legacy_tokenizer.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of query analysis for the legacy tokenizer.
9 """
10 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
11 from copy import copy
12 from collections import defaultdict
13 import dataclasses
14
15 import sqlalchemy as sa
16
17 from ..typing import SaRow
18 from ..connection import SearchConnection
19 from ..logging import log
20 from . import query as qmod
21 from .query_analyzer_factory import AbstractQueryAnalyzer
22
def yield_words(terms: List[str], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Enumerate every run of consecutive terms that begins at or after
        position `start` of the terms list.

        For each starting position, runs grow one term at a time up to a
        maximum of 20 terms (never past the end of the list). Each run is
        yielded as its terms joined with single blanks, together with the
        token range [begin, end) it covers.
    """
    num_terms = len(terms)
    for begin in range(start, num_terms):
        phrase = terms[begin]
        yield phrase, qmod.TokenRange(begin, begin + 1)
        # Longer runs: `end` is exclusive, so a run may cover at most
        # 20 terms (begin .. begin + 19).
        stop = min(begin + 20, num_terms)
        for end in range(begin + 2, stop + 1):
            phrase = f'{phrase} {terms[end - 1]}'
            yield phrase, qmod.TokenRange(begin, end)
34
35
@dataclasses.dataclass
class LegacyToken(qmod.Token):
    """ Query token created from an entry of the legacy tokenizer's
        word table.

        Extends the generic token with the word token string and the
        special-term information (category, country, operator).
    """
    # Token string, without the leading blank that marks full words.
    word_token: str
    # (class, type) pair for special terms, None for other tokens.
    category: Optional[Tuple[str, str]]
    # Country code for country tokens, None for other tokens.
    country: Optional[str]
    # Operator of a special term (e.g. 'in' or 'near'), None otherwise.
    operator: Optional[str]

    @property
    def info(self) -> Dict[str, Any]:
        """ Extra properties of the token collected into a dictionary.
            Meant for debugging output only.
        """
        return dict(category=self.category,
                    country=self.country,
                    operator=self.operator)


    def get_category(self) -> Tuple[str, str]:
        """ Return the (class, type) category of the token.

            Must only be called for tokens that carry a category;
            this is checked with an assertion.
        """
        assert self.category
        return self.category
58
59
class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by a legacy tokenizer.
    """

    def __init__(self, conn: SearchConnection) -> None:
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.

            Reads the 'tokenizer_maxwordfreq' property, which decides
            whether a partial term counts as indexed. Also registers a
            description of the 'word' table with the connection's
            metadata, unless one is already present.
        """
        self.max_word_freq = int(await self.conn.get_property('tokenizer_maxwordfreq'))
        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('class', sa.Text),
                     sa.Column('type', sa.Text),
                     sa.Column('country_code', sa.Text),
                     sa.Column('search_name_count', sa.Integer),
                     sa.Column('operator', sa.Text))


    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.

            The phrases are first normalized in the database with the SQL
            function make_standard_name(). Phrases for which it returns an
            empty result are dropped. The remaining terms, and runs of them,
            are looked up in the word table. Every match is added to the
            query as a token of the type determined by make_token().
        """
        log().section('Analyze query (using Legacy tokenizer)')

        normalized = []
        if phrases:
            for row in await self.conn.execute(sa.select(*(sa.func.make_standard_name(p.text)
                                                           for p in phrases))):
                normalized = [qmod.Phrase(p.ptype, r) for r, p in zip(row, phrases) if r]
                break

        query = qmod.QueryStruct(normalized)
        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        lookup_words = list(words.keys())
        log().var_dump('Split query', parts)
        log().var_dump('Extracted words', lookup_words)

        for row in await self.lookup_in_db(lookup_words):
            for trange in words[row.word_token.strip()]:
                # make_token() is called per range, so every range gets
                # its own token object.
                token, ttype = self.make_token(row)
                if ttype == qmod.TokenType.NEAR_ITEM:
                    # Near items are only accepted at the start of the query.
                    if trange.start == 0:
                        query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                elif ttype == qmod.TokenType.QUALIFIER:
                    query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                    # At the start or the end of the query a qualifier may
                    # also act as a near item, with a penalty that grows with
                    # the length of the query.
                    if trange.start == 0 or trange.end == query.num_token_slots():
                        token = copy(token)
                        token.penalty += 0.1 * (query.num_token_slots())
                        query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                elif ttype != qmod.TokenType.PARTIAL or trange.start + 1 == trange.end:
                    # Partial terms are only used for single-term ranges.
                    query.add_token(trange, ttype, token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query


    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form.

            This only removes case, so some difference with the normalization
            in the phrase remains.
        """
        return text.lower()


    def split_query(self, query: qmod.QueryStruct) -> Tuple[List[str],
                                                            Dict[str, List[qmod.TokenRange]]]:
        """ Transliterate the phrases and split them into tokens.

            Appends one node per term to `query` and sets the break types
            between terms (word), between phrases (phrase) and at the end.

            Returns a list of transliterated tokens and a dictionary
            of words for lookup together with their position.
        """
        parts: List[str] = []
        phrase_start = 0
        words = defaultdict(list)
        for phrase in query.source:
            query.nodes[-1].ptype = phrase.ptype
            for term in phrase.text.split(' '):
                # Consecutive blanks produce empty strings; skip them.
                if term:
                    parts.append(term)
                    query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType.WORD
            query.nodes[-1].btype = qmod.BreakType.PHRASE
            # Word runs never cross a phrase boundary.
            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)
            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words


    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.

            Each word is looked up both as is (partial term) and with a
            leading blank, which is how full words are stored.
        """
        t = self.conn.t.meta.tables['word']

        sql = t.select().where(t.c.word_token.in_(words + [' ' + w for w in words]))

        return await self.conn.execute(sql)


    def make_token(self, row: SaRow) -> Tuple[LegacyToken, qmod.TokenType]:
        """ Create a LegacyToken from the row of the word table.
            Also determines the type of token.

            Rows with a country code become country tokens. Rows with a
            class become house numbers (place/house), postcodes
            (place/postcode), near items (operator 'in' or 'near') or
            qualifiers. Rows whose word token starts with a blank are
            full words. Everything else is a partial term.
        """
        penalty = 0.0
        is_indexed = True

        rowclass = getattr(row, 'class')

        if row.country_code is not None:
            ttype = qmod.TokenType.COUNTRY
            lookup_word = row.country_code
        elif rowclass is not None:
            if rowclass == 'place' and row.type == 'house':
                ttype = qmod.TokenType.HOUSENUMBER
                lookup_word = row.word_token[1:]
            elif rowclass == 'place' and row.type == 'postcode':
                ttype = qmod.TokenType.POSTCODE
                lookup_word = row.word_token[1:]
            else:
                ttype = qmod.TokenType.NEAR_ITEM if row.operator in ('in', 'near')\
                        else qmod.TokenType.QUALIFIER
                lookup_word = row.word
        elif row.word_token.startswith(' '):
            ttype = qmod.TokenType.WORD
            lookup_word = row.word or row.word_token[1:]
        else:
            ttype = qmod.TokenType.PARTIAL
            lookup_word = row.word_token
            penalty = 0.21
            # search_name_count is a nullable column; a missing count must
            # not crash the comparison (None > int raises TypeError).
            if (row.search_name_count or 0) > self.max_word_freq:
                is_indexed = False

        return LegacyToken(penalty=penalty, token=row.word_id,
                           count=max(1, row.search_name_count or 1),
                           addr_count=1, # not supported
                           lookup_word=lookup_word,
                           word_token=row.word_token.strip(),
                           category=(rowclass, row.type) if rowclass is not None else None,
                           country=row.country_code,
                           operator=row.operator,
                           is_indexed=is_indexed),\
               ttype


    def add_extra_tokens(self, query: qmod.QueryStruct, parts: List[str]) -> None:
        """ Add tokens to query that are not saved in the database.

            Every term of at most four digits gets a synthetic house number
            token (penalty 0.5) unless the node already reports a house
            number token for position i+1.
        """
        # range(1000) supplies the term index and also caps how many
        # terms are examined.
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part) <= 4 and part.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                LegacyToken(penalty=0.5, token=0, count=1, addr_count=1,
                                            lookup_word=part, word_token=part,
                                            category=None, country=None,
                                            operator=None, is_indexed=True))


    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
        """ Add penalties to tokens that depend on presence of other token.

            A postcode penalizes (by 0.39) other token lists that start at
            the same node and end at the same position. House numbers are
            exempt unless the postcode's lookup word is longer than four
            characters. A short house number (at most three characters,
            containing a digit) penalizes other non-house-number lists over
            the same range by 0.5 minus its own penalty.
        """
        for _, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
                 and len(tlist.tokens[0].lookup_word) <= 3:
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
252
253
254
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    """ Produce rows describing every token of the query, for debug logging.

        The first row is the header. Each following row describes one
        token, in the order of the query's nodes and their token lists.
    """
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for tlist in (tl for node in query.nodes for tl in node.starting):
        for raw in tlist.tokens:
            tok = cast(LegacyToken, raw)
            # Missing word strings are shown as empty cells.
            yield [tlist.ttype.name, tok.token, tok.word_token or '',
                   tok.lookup_word or '', tok.penalty, tok.count, tok.info]
263
264
async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the legacy tokenizer.

        The analyzer's setup() is awaited before it is returned, so it
        is ready for use.
    """
    analyzer = LegacyQueryAnalyzer(conn)
    await analyzer.setup()

    return analyzer