]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/icu_tokenizer.py
prefer min() function over if construct
[nominatim.git] / nominatim / api / search / icu_tokenizer.py
index 14698a28867ca7ae0fc783f6b6e11385ffe45d8a..1c2565d1ad60c80df1f1ecb78b216439b8d98224 100644 (file)
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
-from copy import copy
 from collections import defaultdict
 import dataclasses
 import difflib
 from collections import defaultdict
 import dataclasses
 import difflib
@@ -21,10 +20,8 @@ from nominatim.typing import SaRow
 from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
-
-# XXX: TODO
-class AbstractQueryAnalyzer:
-    pass
+from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
+from nominatim.db.sqlalchemy_types import Json
 
 
 DB_TO_TOKEN_TYPE = {
 
 
 DB_TO_TOKEN_TYPE = {
@@ -86,7 +83,7 @@ class ICUToken(qmod.Token):
         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
         distance = 0
         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
         distance = 0
         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
-            if tag == 'delete' and (afrom == 0 or ato == len(self.lookup_word)):
+            if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
                 distance += 1
             elif tag == 'replace':
                 distance += max((ato-afrom), (bto-bfrom))
                 distance += 1
             elif tag == 'replace':
                 distance += max((ato-afrom), (bto-bfrom))
@@ -104,10 +101,16 @@ class ICUToken(qmod.Token):
         penalty = 0.0
         if row.type == 'w':
             penalty = 0.3
         penalty = 0.0
         if row.type == 'w':
             penalty = 0.3
+        elif row.type == 'W':
+            if len(row.word_token) == 1 and row.word_token == row.word:
+                penalty = 0.2 if row.word.isdigit() else 0.3
         elif row.type == 'H':
             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
             if all(not c.isdigit() for c in row.word_token):
                 penalty += 0.2 * (len(row.word_token) - 1)
         elif row.type == 'H':
             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
             if all(not c.isdigit() for c in row.word_token):
                 penalty += 0.2 * (len(row.word_token) - 1)
+        elif row.type == 'C':
+            if len(row.word_token) == 1:
+                penalty = 0.3
 
         if row.info is None:
             lookup_word = row.word
 
         if row.info is None:
             lookup_word = row.word
@@ -136,10 +139,19 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
-        rules = await self.conn.get_property('tokenizer_import_normalisation')
-        self.normalizer = Transliterator.createFromRules("normalization", rules)
-        rules = await self.conn.get_property('tokenizer_import_transliteration')
-        self.transliterator = Transliterator.createFromRules("transliteration", rules)
+        async def _make_normalizer() -> Any:
+            rules = await self.conn.get_property('tokenizer_import_normalisation')
+            return Transliterator.createFromRules("normalization", rules)
+
+        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
+                                                           _make_normalizer)
+
+        async def _make_transliterator() -> Any:
+            rules = await self.conn.get_property('tokenizer_import_transliteration')
+            return Transliterator.createFromRules("transliteration", rules)
+
+        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
+                                                               _make_transliterator)
 
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
 
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
@@ -147,7 +159,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
-                     sa.Column('info', self.conn.t.types.Json))
+                     sa.Column('info', Json))
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
@@ -156,7 +168,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         log().section('Analyze query (using ICU tokenizer)')
         normalized = list(filter(lambda p: p.text,
         """
         log().section('Analyze query (using ICU tokenizer)')
         normalized = list(filter(lambda p: p.text,
-                                 (qmod.Phrase(p.ptype, self.normalizer.transliterate(p.text))
+                                 (qmod.Phrase(p.ptype, self.normalize_text(p.text))
                                   for p in phrases)))
         query = qmod.QueryStruct(normalized)
         log().var_dump('Normalized query', query.source)
                                   for p in phrases)))
         query = qmod.QueryStruct(normalized)
         log().var_dump('Normalized query', query.source)
@@ -172,13 +184,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
-                            query.add_token(trange, qmod.TokenType.CATEGORY, token)
+                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                     else:
                     else:
-                        query.add_token(trange, qmod.TokenType.QUALIFIER, token)
-                        if trange.start == 0 or trange.end == query.num_token_slots():
-                            token = copy(token)
-                            token.penalty += 0.1 * (query.num_token_slots())
-                            query.add_token(trange, qmod.TokenType.CATEGORY, token)
+                        if trange.start == 0 and trange.end == query.num_token_slots():
+                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
+                        else:
+                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
@@ -190,6 +201,14 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         return query
 
 
         return query
 
 
+    def normalize_text(self, text: str) -> str:
+        """ Bring the given text into a normalized form. That is the
+            standardized form search will work with. All information removed
+            at this stage is inevitably lost.
+        """
+        return cast(str, self.normalizer.transliterate(text))
+
+
     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
         """ Transliterate the phrases and split them into tokens.
 
     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
         """ Transliterate the phrases and split them into tokens.
 
@@ -251,12 +270,11 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
                             or len(tlist.tokens[0].lookup_word) > 4):
                         repl.add_penalty(0.39)
                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
                             or len(tlist.tokens[0].lookup_word) > 4):
                         repl.add_penalty(0.39)
-            elif tlist.ttype == qmod.TokenType.HOUSENUMBER:
+            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
+                 and len(tlist.tokens[0].lookup_word) <= 3:
                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                     for repl in node.starting:
                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                     for repl in node.starting:
-                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER \
-                           and (repl.ttype != qmod.TokenType.HOUSENUMBER
-                                or len(tlist.tokens[0].lookup_word) <= 3):
+                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                 norm = parts[i].normalized
                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                 norm = parts[i].normalized