]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/icu_tokenizer.py
Merge pull request #3383 from lonvia/window-searches
[nominatim.git] / nominatim / api / search / icu_tokenizer.py
index 7bf516e3aa25c12775d0c3ffae7271076f6d0032..eb90c122eb43277a18b2d71bf48eb9ab99a375ac 100644 (file)
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
-from copy import copy
 from collections import defaultdict
 import dataclasses
 import difflib
 from collections import defaultdict
 import dataclasses
 import difflib
@@ -22,6 +21,7 @@ from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
+from nominatim.db.sqlalchemy_types import Json
 
 
 DB_TO_TOKEN_TYPE = {
 
 
 DB_TO_TOKEN_TYPE = {
@@ -97,14 +97,21 @@ class ICUToken(qmod.Token):
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
+        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
         penalty = 0.0
         if row.type == 'w':
             penalty = 0.3
 
         penalty = 0.0
         if row.type == 'w':
             penalty = 0.3
+        elif row.type == 'W':
+            if len(row.word_token) == 1 and row.word_token == row.word:
+                penalty = 0.2 if row.word.isdigit() else 0.3
         elif row.type == 'H':
             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
             if all(not c.isdigit() for c in row.word_token):
                 penalty += 0.2 * (len(row.word_token) - 1)
         elif row.type == 'H':
             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
             if all(not c.isdigit() for c in row.word_token):
                 penalty += 0.2 * (len(row.word_token) - 1)
+        elif row.type == 'C':
+            if len(row.word_token) == 1:
+                penalty = 0.3
 
         if row.info is None:
             lookup_word = row.word
 
         if row.info is None:
             lookup_word = row.word
@@ -115,9 +122,10 @@ class ICUToken(qmod.Token):
         else:
             lookup_word = row.word_token
 
         else:
             lookup_word = row.word_token
 
-        return ICUToken(penalty=penalty, token=row.word_id, count=count,
+        return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
                         lookup_word=lookup_word, is_indexed=True,
                         lookup_word=lookup_word, is_indexed=True,
-                        word_token=row.word_token, info=row.info)
+                        word_token=row.word_token, info=row.info,
+                        addr_count=max(1, addr_count))
 
 
 
 
 
 
@@ -133,10 +141,19 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
-        rules = await self.conn.get_property('tokenizer_import_normalisation')
-        self.normalizer = Transliterator.createFromRules("normalization", rules)
-        rules = await self.conn.get_property('tokenizer_import_transliteration')
-        self.transliterator = Transliterator.createFromRules("transliteration", rules)
+        async def _make_normalizer() -> Any:
+            rules = await self.conn.get_property('tokenizer_import_normalisation')
+            return Transliterator.createFromRules("normalization", rules)
+
+        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
+                                                           _make_normalizer)
+
+        async def _make_transliterator() -> Any:
+            rules = await self.conn.get_property('tokenizer_import_transliteration')
+            return Transliterator.createFromRules("transliteration", rules)
+
+        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
+                                                               _make_transliterator)
 
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
 
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
@@ -144,7 +161,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
-                     sa.Column('info', self.conn.t.types.Json))
+                     sa.Column('info', Json))
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
@@ -169,13 +186,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
-                            query.add_token(trange, qmod.TokenType.CATEGORY, token)
+                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                     else:
                     else:
-                        query.add_token(trange, qmod.TokenType.QUALIFIER, token)
-                        if trange.start == 0 or trange.end == query.num_token_slots():
-                            token = copy(token)
-                            token.penalty += 0.1 * (query.num_token_slots())
-                            query.add_token(trange, qmod.TokenType.CATEGORY, token)
+                        if trange.start == 0 and trange.end == query.num_token_slots():
+                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
+                        else:
+                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
@@ -243,7 +259,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                ICUToken(0.5, 0, 1, part.token, True, part.token, None))
+                                ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None))
 
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
 
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: