Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
index 8f2069c1a8bc57212956d1ce64d3a5914c12f920..3b85f26df7f1eea53323851e46d5479450d48342 100644
@@ -29,36 +29,37 @@ from .query_analyzer_factory import AbstractQueryAnalyzer
 
 
 DB_TO_TOKEN_TYPE = {
-    'W': qmod.TokenType.WORD,
-    'w': qmod.TokenType.PARTIAL,
-    'H': qmod.TokenType.HOUSENUMBER,
-    'P': qmod.TokenType.POSTCODE,
-    'C': qmod.TokenType.COUNTRY
+    'W': qmod.TOKEN_WORD,
+    'w': qmod.TOKEN_PARTIAL,
+    'H': qmod.TOKEN_HOUSENUMBER,
+    'P': qmod.TOKEN_POSTCODE,
+    'C': qmod.TOKEN_COUNTRY
 }
 
 PENALTY_IN_TOKEN_BREAK = {
-     qmod.BreakType.START: 0.5,
-     qmod.BreakType.END: 0.5,
-     qmod.BreakType.PHRASE: 0.5,
-     qmod.BreakType.SOFT_PHRASE: 0.5,
-     qmod.BreakType.WORD: 0.1,
-     qmod.BreakType.PART: 0.0,
-     qmod.BreakType.TOKEN: 0.0
+     qmod.BREAK_START: 0.5,
+     qmod.BREAK_END: 0.5,
+     qmod.BREAK_PHRASE: 0.5,
+     qmod.BREAK_SOFT_PHRASE: 0.5,
+     qmod.BREAK_WORD: 0.1,
+     qmod.BREAK_PART: 0.0,
+     qmod.BREAK_TOKEN: 0.0
 }
 
 
 @dataclasses.dataclass
 class QueryPart:
     """ Normalized and transliterated form of a single term in the query.
+
         When the term came out of a split during the transliteration,
         the normalized string is the full word before transliteration.
-        The word number keeps track of the word before transliteration
-        and can be used to identify partial transliterated terms.
+        Check the subsequent break type to figure out if the word is
+        continued.
+
         Penalty is the break penalty for the break following the token.
     """
     token: str
     normalized: str
-    word_number: int
     penalty: float
 
 
@@ -66,19 +67,20 @@ QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
 
-def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
-    """ Return all combinations of words in the terms list after the
-        given position.
+def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
+    """ Add all combinations of words in the terms list after the
+        given position to the word list.
     """
     total = len(terms)
+    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
     for first in range(start, total):
         word = terms[first].token
-        penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
-        yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
+        penalty = base_penalty
+        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
         for last in range(first + 1, min(first + 20, total)):
             word = ' '.join((word, terms[last].token))
             penalty += terms[last - 1].penalty
-            yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
+            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
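
To see what the new in-place collection produces, here is a self-contained sketch with simplified stand-ins for QueryPart and qmod.TokenRange (both hypothetical; only the fields used here):

    from collections import defaultdict
    from typing import Dict, List, NamedTuple

    class Part(NamedTuple):    # simplified QueryPart
        token: str
        penalty: float

    class Range(NamedTuple):   # simplified qmod.TokenRange
        start: int
        end: int
        penalty: float

    def extract_words(terms: List[Part], start: int,
                      words: Dict[str, List[Range]]) -> None:
        base_penalty = 0.1     # PENALTY_IN_TOKEN_BREAK[BREAK_WORD]
        for first in range(start, len(terms)):
            word = terms[first].token
            penalty = base_penalty
            words[word].append(Range(first, first + 1, penalty))
            for last in range(first + 1, min(first + 20, len(terms))):
                word = ' '.join((word, terms[last].token))
                penalty += terms[last - 1].penalty
                words[word].append(Range(first, last + 1, penalty))

    words: Dict[str, List[Range]] = defaultdict(list)
    extract_words([Part('new', 0.1), Part('york', 0.1), Part('city', 0.1)], 0, words)
    # words now holds 'new', 'york', 'city' plus the multi-term entries
    # 'new york', 'new york city' and 'york city', each with a range
    # whose penalty accumulates the breaks between the combined terms.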
 
 
 @dataclasses.dataclass
@@ -229,12 +231,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
-                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
+                            query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
                     else:
                         if trange.start == 0 and trange.end == query.num_token_slots():
-                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
+                            query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
                         else:
-                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
+                            query.add_token(trange, qmod.TOKEN_QUALIFIER, token)
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
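
The branch above routes special terms (row type 'S') by operator and position; restated as a standalone sketch (hypothetical helper, not the real API):

    from typing import Optional

    def classify_special(op: str, start: int, end: int, num_slots: int) -> Optional[str]:
        if op in ('in', 'near'):
            # "restaurant in Rome": only usable at the start of the query
            return 'NEAR_ITEM' if start == 0 else None
        if start == 0 and end == num_slots:
            return 'NEAR_ITEM'   # the whole query is the special term
        return 'QUALIFIER'       # otherwise it only qualifies another place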
 
@@ -250,7 +252,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             standardized form search will work with. All information removed
             at this stage is inevitably lost.
         """
-        return cast(str, self.normalizer.transliterate(text))
+        return cast(str, self.normalizer.transliterate(text)).strip('-: ')
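
The newly appended .strip('-: ') trims stray break characters from the ends of the normalized string, which avoids empty leading or trailing terms later on. Plain str.strip semantics, for illustration:

    # strip() removes any of the listed characters from both ends only
    assert 'main street-'.strip('-: ') == 'main street'
    assert ':sankt-johann '.strip('-: ') == 'sankt-johann'  # inner '-' kept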
 
     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
         """ Transliterate the phrases and split them into tokens.
@@ -261,8 +263,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         parts: QueryParts = []
         phrase_start = 0
-        words = defaultdict(list)
-        wordnr = 0
+        words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
@@ -277,18 +278,16 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if trans:
                     for term in trans.split(' '):
                         if term:
-                            parts.append(QueryPart(term, word, wordnr,
-                                                   PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
-                            query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
-                    query.nodes[-1].btype = qmod.BreakType(breakchar)
-                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
-                wordnr += 1
+                            parts.append(QueryPart(term, word,
+                                                   PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
+                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
+                    query.nodes[-1].btype = breakchar
+                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
 
-            for word, wrange in yield_words(parts, phrase_start):
-                words[word].append(wrange)
+            extract_words(parts, phrase_start, words)
 
             phrase_start = len(parts)
-        query.nodes[-1].btype = qmod.BreakType.END
+        query.nodes[-1].btype = qmod.BREAK_END
 
         return parts, words
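
For intuition about the splitting loop: re.split with a capturing group keeps the separators, so words and the break characters after them alternate. A small illustration (the node comments paraphrase the code above):

    import re

    print(re.split('([ :-])', 'brook st:12'))
    # ['brook', ' ', 'st', ':', '12']
    # Pairs off as ('brook', ' '), ('st', ':'), ('12', end-of-phrase).
    # Each word is normalized, then transliterated; if transliteration
    # splits it into several terms, BREAK_TOKEN nodes go between them and
    # the captured break character becomes the btype of the last node.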
 
@@ -304,8 +303,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         for part, node, i in zip(parts, query.nodes, range(1000)):
             if len(part.token) <= 4 and part.token.isdigit()\
-               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
-                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
+               and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
+                query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
                                 ICUToken(penalty=0.5, token=0,
                                          count=1, addr_count=1, lookup_word=part.token,
                                          word_token=part.token, info=None))
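
The guard above only synthesizes a house-number reading for short, all-digit terms that don't already carry one; the penalty of 0.5 marks it as a fallback. The condition in isolation:

    def looks_like_housenumber(token: str) -> bool:
        # at most four characters, digits only: '12', '2024' yes; '12a', '12345' no
        return len(token) <= 4 and token.isdigit()

    assert looks_like_housenumber('12') and looks_like_housenumber('2024')
    assert not looks_like_housenumber('12a') and not looks_like_housenumber('12345')
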
@@ -314,31 +313,31 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """ Add penalties to tokens that depend on presence of other token.
         """
         for i, node, tlist in query.iter_token_lists():
-            if tlist.ttype == qmod.TokenType.POSTCODE:
+            if tlist.ttype == qmod.TOKEN_POSTCODE:
                 for repl in node.starting:
-                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
-                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
+                    if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
+                       and (repl.ttype != qmod.TOKEN_HOUSENUMBER
                             or len(tlist.tokens[0].lookup_word) > 4):
                         repl.add_penalty(0.39)
-            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
+            elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
                   and len(tlist.tokens[0].lookup_word) <= 3):
                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                     for repl in node.starting:
-                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
+                        if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
-            elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
+            elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
                 norm = parts[i].normalized
                 for j in range(i + 1, tlist.end):
-                    if parts[j - 1].word_number != parts[j].word_number:
+                    if node.btype != qmod.BREAK_TOKEN:
                         norm += '  ' + parts[j].normalized
                 for token in tlist.tokens:
                     cast(ICUToken, token).rematch(norm)
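
Numerically, the house-number branch above penalizes every competing reading of the same range by 0.5 minus the house-number token's own penalty. An illustration with made-up numbers:

    housenumber_penalty = 0.2          # hypothetical token penalty
    other_reading = 0.1
    other_reading += 0.5 - housenumber_penalty
    print(other_reading)               # 0.4: the house-number reading now wins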
 
 
 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
-    out = query.nodes[0].btype.value
+    out = query.nodes[0].btype
     for node, part in zip(query.nodes[1:], parts):
-        out += part.token + node.btype.value
+        out += part.token + node.btype
     return out
 
 
@@ -348,7 +347,7 @@ def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
         for tlist in node.starting:
             for token in tlist.tokens:
                 t = cast(ICUToken, token)
-                yield [tlist.ttype.name, t.token, t.word_token or '',
+                yield [tlist.ttype, t.token, t.word_token or '',
                        t.lookup_word or '', t.penalty, t.count, t.info]
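
Since btype is now a plain character (and ttype a plain string), the dump helpers concatenate without .value or .name. A sketch of the string _dump_transliterated builds; the concrete break characters ('<', ' ', ',', '>') are assumptions about the constant values, not taken from this diff:

    btypes = ['<', ' ', ',', '>']        # one btype per node (assumed values)
    tokens = ['brook', 'st', '12']       # one token per part
    out = btypes[0]
    for bt, tok in zip(btypes[1:], tokens):
        out += tok + bt
    print(out)                           # <brook st,12>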