]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/search/icu_tokenizer.py
consistently use query module as qmod
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
index 487dd1710354c9c9801cef9ce22033c1a5169e87..d4d0643f3a59bd46f70caa0393e41a78726e9409 100644 (file)
@@ -50,15 +50,16 @@ PENALTY_IN_TOKEN_BREAK = {
 @dataclasses.dataclass
 class QueryPart:
     """ Normalized and transliterated form of a single term in the query.
+
         When the term came out of a split during the transliteration,
         the normalized string is the full word before transliteration.
-        The word number keeps track of the word before transliteration
-        and can be used to identify partial transliterated terms.
+        Check the break type that follows the term to find out whether
+        the word is continued.
+
         Penalty is the break penalty for the break following the token.
     """
     token: str
     normalized: str
-    word_number: int
     penalty: float
 
 
@@ -66,19 +67,20 @@ QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
 
-def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
-    """ Return all combinations of words in the terms list after the
-        given position.
+def extract_words(terms: List[QueryPart], start: int,  words: WordDict) -> None:
+    """ Add all combinations of words in the terms list after the
+        given position to the word dictionary.
     """
     total = len(terms)
+    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
     for first in range(start, total):
         word = terms[first].token
-        penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
-        yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
+        penalty = base_penalty
+        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
         for last in range(first + 1, min(first + 20, total)):
             word = ' '.join((word, terms[last].token))
             penalty += terms[last - 1].penalty
-            yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
+            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
 
 
 @dataclasses.dataclass
@@ -255,8 +257,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         parts: QueryParts = []
         phrase_start = 0
-        words = defaultdict(list)
-        wordnr = 0
+        words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
@@ -271,15 +272,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if trans:
                     for term in trans.split(' '):
                         if term:
-                            parts.append(QueryPart(term, word, wordnr,
+                            parts.append(QueryPart(term, word,
                                                    PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                     query.nodes[-1].btype = qmod.BreakType(breakchar)
                     parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
-                wordnr += 1
 
-            for word, wrange in yield_words(parts, phrase_start):
-                words[word].append(wrange)
+            extract_words(parts, phrase_start, words)
 
             phrase_start = len(parts)
         query.nodes[-1].btype = qmod.BreakType.END
@@ -323,7 +322,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                 norm = parts[i].normalized
                 for j in range(i + 1, tlist.end):
-                    if parts[j - 1].word_number != parts[j].word_number:
+                    if node.btype != qmod.BreakType.TOKEN:
                         norm += '  ' + parts[j].normalized
                 for token in tlist.tokens:
                     cast(ICUToken, token).rematch(norm)