git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/icu_tokenizer.py
don't even try heavily penalized searches
[nominatim.git] / nominatim / api / search / icu_tokenizer.py
index ad08294e00f8409c04043d9187a96198aebc8e16..23cfa5a166c003a1b5638f0334d10636a335d935 100644 (file)
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
-from copy import copy
 from collections import defaultdict
 import dataclasses
 import difflib
 from collections import defaultdict
 import dataclasses
 import difflib
@@ -22,6 +21,7 @@ from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
+from nominatim.db.sqlalchemy_types import Json
 
 
 DB_TO_TOKEN_TYPE = {
 
 
 DB_TO_TOKEN_TYPE = {
@@ -83,7 +83,7 @@ class ICUToken(qmod.Token):
         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
         distance = 0
         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
         distance = 0
         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
-            if tag == 'delete' and (afrom == 0 or ato == len(self.lookup_word)):
+            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                 distance += 1
             elif tag == 'replace':
                 distance += max((ato-afrom), (bto-bfrom))
                 distance += 1
             elif tag == 'replace':
                 distance += max((ato-afrom), (bto-bfrom))
@@ -97,14 +97,21 @@ class ICUToken(qmod.Token):
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
+        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
         penalty = 0.0
         if row.type == 'w':
             penalty = 0.3
 
         penalty = 0.0
         if row.type == 'w':
             penalty = 0.3
+        elif row.type == 'W':
+            if len(row.word_token) == 1 and row.word_token == row.word:
+                penalty = 0.2 if row.word.isdigit() else 0.3
         elif row.type == 'H':
             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
             if all(not c.isdigit() for c in row.word_token):
                 penalty += 0.2 * (len(row.word_token) - 1)
         elif row.type == 'H':
             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
             if all(not c.isdigit() for c in row.word_token):
                 penalty += 0.2 * (len(row.word_token) - 1)
+        elif row.type == 'C':
+            if len(row.word_token) == 1:
+                penalty = 0.3
 
         if row.info is None:
             lookup_word = row.word
 
         if row.info is None:
             lookup_word = row.word
@@ -117,7 +124,8 @@ class ICUToken(qmod.Token):
 
         return ICUToken(penalty=penalty, token=row.word_id, count=count,
                         lookup_word=lookup_word, is_indexed=True,
 
         return ICUToken(penalty=penalty, token=row.word_id, count=count,
                         lookup_word=lookup_word, is_indexed=True,
-                        word_token=row.word_token, info=row.info)
+                        word_token=row.word_token, info=row.info,
+                        addr_count=addr_count)
 
 
 
 
 
 
@@ -133,10 +141,19 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
-        rules = await self.conn.get_property('tokenizer_import_normalisation')
-        self.normalizer = Transliterator.createFromRules("normalization", rules)
-        rules = await self.conn.get_property('tokenizer_import_transliteration')
-        self.transliterator = Transliterator.createFromRules("transliteration", rules)
+        async def _make_normalizer() -> Any:
+            rules = await self.conn.get_property('tokenizer_import_normalisation')
+            return Transliterator.createFromRules("normalization", rules)
+
+        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
+                                                           _make_normalizer)
+
+        async def _make_transliterator() -> Any:
+            rules = await self.conn.get_property('tokenizer_import_transliteration')
+            return Transliterator.createFromRules("transliteration", rules)
+
+        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
+                                                               _make_transliterator)
 
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
 
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
@@ -144,7 +161,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
-                     sa.Column('info', self.conn.t.types.Json))
+                     sa.Column('info', Json))
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
@@ -169,13 +186,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
-                            query.add_token(trange, qmod.TokenType.CATEGORY, token)
+                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                     else:
                     else:
-                        query.add_token(trange, qmod.TokenType.QUALIFIER, token)
-                        if trange.start == 0 or trange.end == query.num_token_slots():
-                            token = copy(token)
-                            token.penalty += 0.1 * (query.num_token_slots())
-                            query.add_token(trange, qmod.TokenType.CATEGORY, token)
+                        if trange.start == 0 and trange.end == query.num_token_slots():
+                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
+                        else:
+                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
@@ -248,7 +264,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                ICUToken(0.5, 0, 1, part.token, True, part.token, None))
+                                ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None))
 
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
 
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: