]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/search/icu_tokenizer.py
Merge pull request #3587 from danieldegroot2/lookup-spelling
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
index 426656797c873b2a4419c3913d4ac3f66d5c4610..fa14531aed0d6c07cf79c277255324495b1b063d 100644 (file)
@@ -16,8 +16,8 @@ from icu import Transliterator
 
 import sqlalchemy as sa
 
 
 import sqlalchemy as sa
 
-from nominatim_core.typing import SaRow
-from nominatim_core.db.sqlalchemy_types import Json
+from ..typing import SaRow
+from ..sql.sqlalchemy_types import Json
 from ..connection import SearchConnection
 from ..logging import log
 from ..search import query as qmod
 from ..connection import SearchConnection
 from ..logging import log
 from ..search import query as qmod
@@ -48,6 +48,7 @@ class QueryPart(NamedTuple):
 QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
 QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
+
 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
     """ Return all combinations of words in the terms list after the
         given position.
 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
     """ Return all combinations of words in the terms list after the
         given position.
@@ -72,7 +73,6 @@ class ICUToken(qmod.Token):
         assert self.info
         return self.info.get('class', ''), self.info.get('type', '')
 
         assert self.info
         return self.info.get('class', ''), self.info.get('type', '')
 
-
     def rematch(self, norm: str) -> None:
         """ Check how well the token matches the given normalized string
             and add a penalty, if necessary.
     def rematch(self, norm: str) -> None:
         """ Check how well the token matches the given normalized string
             and add a penalty, if necessary.
@@ -91,7 +91,6 @@ class ICUToken(qmod.Token):
                 distance += abs((ato-afrom) - (bto-bfrom))
         self.penalty += (distance/len(self.lookup_word))
 
                 distance += abs((ato-afrom) - (bto-bfrom))
         self.penalty += (distance/len(self.lookup_word))
 
-
     @staticmethod
     def from_db_row(row: SaRow) -> 'ICUToken':
         """ Create a ICUToken from the row of the word table.
     @staticmethod
     def from_db_row(row: SaRow) -> 'ICUToken':
         """ Create a ICUToken from the row of the word table.
@@ -123,21 +122,18 @@ class ICUToken(qmod.Token):
             lookup_word = row.word_token
 
         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
             lookup_word = row.word_token
 
         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
-                        lookup_word=lookup_word, is_indexed=True,
+                        lookup_word=lookup_word,
                         word_token=row.word_token, info=row.info,
                         addr_count=max(1, addr_count))
 
 
                         word_token=row.word_token, info=row.info,
                         addr_count=max(1, addr_count))
 
 
-
 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     """ Converter for query strings into a tokenized query
         using the tokens created by a ICU tokenizer.
     """
 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     """ Converter for query strings into a tokenized query
         using the tokens created by a ICU tokenizer.
     """
-
     def __init__(self, conn: SearchConnection) -> None:
         self.conn = conn
 
     def __init__(self, conn: SearchConnection) -> None:
         self.conn = conn
 
-
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
@@ -163,7 +159,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                      sa.Column('word', sa.Text),
                      sa.Column('info', Json))
 
                      sa.Column('word', sa.Text),
                      sa.Column('info', Json))
 
-
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
             tokenized query.
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
             tokenized query.
@@ -202,7 +197,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 
         return query
 
 
         return query
 
-
     def normalize_text(self, text: str) -> str:
         """ Bring the given text into a normalized form. That is the
             standardized form search will work with. All information removed
     def normalize_text(self, text: str) -> str:
         """ Bring the given text into a normalized form. That is the
             standardized form search will work with. All information removed
@@ -210,7 +204,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text))
 
         """
         return cast(str, self.normalizer.transliterate(text))
 
-
     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
         """ Transliterate the phrases and split them into tokens.
 
     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
         """ Transliterate the phrases and split them into tokens.
 
@@ -243,7 +236,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 
         return parts, words
 
 
         return parts, words
 
-
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
@@ -251,7 +243,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         t = self.conn.t.meta.tables['word']
         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
 
         t = self.conn.t.meta.tables['word']
         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
 
-
     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add tokens to query that are not saved in the database.
         """
     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add tokens to query that are not saved in the database.
         """
@@ -259,8 +250,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None))
-
+                                ICUToken(penalty=0.5, token=0,
+                                         count=1, addr_count=1, lookup_word=part.token,
+                                         word_token=part.token, info=None))
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add penalties to tokens that depend on presence of other token.
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add penalties to tokens that depend on presence of other token.
@@ -272,8 +264,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
                             or len(tlist.tokens[0].lookup_word) > 4):
                         repl.add_penalty(0.39)
                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
                             or len(tlist.tokens[0].lookup_word) > 4):
                         repl.add_penalty(0.39)
-            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
-                 and len(tlist.tokens[0].lookup_word) <= 3:
+            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
+                  and len(tlist.tokens[0].lookup_word) <= 3):
                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                     for repl in node.starting:
                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                     for repl in node.starting:
                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER: