git.openstreetmap.org Git - nominatim.git/commitdiff
make word generation from query a class method
author Sarah Hoffmann <lonvia@denofr.de>
Wed, 26 Feb 2025 16:22:14 +0000 (17:22 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Tue, 4 Mar 2025 07:57:37 +0000 (08:57 +0100)
src/nominatim_api/search/icu_tokenizer.py
src/nominatim_api/search/query.py
test/python/api/search/test_query.py

index 60e712d59b1af8ce2c0daebc95ce683c948ccd35..e6bba95c6fd071d871ac1813265b734f3d6fdd13 100644 (file)
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
-from collections import defaultdict
 import dataclasses
 import difflib
 import re
@@ -47,29 +46,6 @@ PENALTY_IN_TOKEN_BREAK = {
 }
 
 
-WordDict = Dict[str, List[qmod.TokenRange]]
-
-
-def extract_words(query: qmod.QueryStruct, start: int,  words: WordDict) -> None:
-    """ Add all combinations of words in the terms list starting with
-        the term leading into node 'start'.
-
-        The words found will be added into the 'words' dictionary with
-        their start and end position.
-    """
-    nodes = query.nodes
-    total = len(nodes)
-    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
-    for first in range(start, total):
-        word = nodes[first].term_lookup
-        penalty = base_penalty
-        words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty))
-        for last in range(first + 1, min(first + 20, total)):
-            word = ' '.join((word, nodes[last].term_lookup))
-            penalty += nodes[last - 1].penalty
-            words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty))
-
-
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
@@ -203,8 +179,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         if not query.source:
             return query
 
-        words = self.split_query(query)
+        self.split_query(query)
         log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
+        words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
@@ -235,14 +212,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
-    def split_query(self, query: qmod.QueryStruct) -> WordDict:
+    def split_query(self, query: qmod.QueryStruct) -> None:
         """ Transliterate the phrases and split them into tokens.
         """ Transliterate the phrases and split them into tokens.
-
-            Returns a dictionary of words for lookup together
-            with their position.
         """
         """
-        phrase_start = 1
-        words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
@@ -263,13 +235,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                     query.nodes[-1].adjust_break(breakchar,
                                                  PENALTY_IN_TOKEN_BREAK[breakchar])
 
-            extract_words(query, phrase_start, words)
-
-            phrase_start = len(query.nodes)
         query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
 
-        return words
-
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
index fcd6763bf89bfb5057d0cbdf42ff01094348b897..07bb685b6d36ed5d145faf4e403aedce7668bb12 100644 (file)
@@ -7,8 +7,9 @@
 """
 Datastructures for a tokenized query.
 """
 """
 Datastructures for a tokenized query.
 """
-from typing import List, Tuple, Optional, Iterator
+from typing import Dict, List, Tuple, Optional, Iterator
 from abc import ABC, abstractmethod
+from collections import defaultdict
 import dataclasses
 
 
@@ -320,3 +321,34 @@ class QueryStruct:
             For debugging purposes only.
         """
         return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
+
+    def extract_words(self, base_penalty: float = 0.0,
+                      start: int = 0,
+                      endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
+        """ Add all combinations of words that can be formed from the terms
+            between the given start and endnode. The terms are joined with
+            spaces for each break. Words can never go across a BREAK_PHRASE.
+
+            The function returns a dictionary of possible words with their
+            position within the query and a penalty. The penalty is computed
+            from the base_penalty plus the penalty for each node the word
+            crosses.
+        """
+        if endpos is None:
+            endpos = len(self.nodes)
+
+        words: Dict[str, List[TokenRange]] = defaultdict(list)
+
+        for first in range(start, endpos - 1):
+            word = self.nodes[first + 1].term_lookup
+            penalty = base_penalty
+            words[word].append(TokenRange(first, first + 1, penalty=penalty))
+            if self.nodes[first + 1].btype != BREAK_PHRASE:
+                for last in range(first + 2, min(first + 20, endpos)):
+                    word = ' '.join((word, self.nodes[last].term_lookup))
+                    penalty += self.nodes[last - 1].penalty
+                    words[word].append(TokenRange(first, last, penalty=penalty))
+                    if self.nodes[last].btype == BREAK_PHRASE:
+                        break
+
+        return words
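
For illustration, a minimal sketch of how the new method behaves, using the same node layout as the unit test added further below. The import path is inferred from the file location src/nominatim_api/search/query.py; the penalties in the comments follow the loop above (base_penalty plus the penalty of every inner node a word crosses):

    from nominatim_api.search import query as nq

    # Node 0 is the query's leading node; each add_node() call appends one more.
    q = nq.QueryStruct([])
    q.add_node(nq.BREAK_WORD, nq.PHRASE_ANY, 0.1, '12', '')     # node 1
    q.add_node(nq.BREAK_TOKEN, nq.PHRASE_ANY, 0.0, 'ab', '')    # node 2
    q.add_node(nq.BREAK_PHRASE, nq.PHRASE_ANY, 0.0, '12', '')   # node 3
    q.add_node(nq.BREAK_END, nq.PHRASE_ANY, 0.5, 'hallo', '')   # node 4

    words = q.extract_words(base_penalty=1.0)

    # Single terms carry the base penalty only:
    #   words['12']    == [nq.TokenRange(0, 1, 1.0), nq.TokenRange(2, 3, 1.0)]
    #   words['hallo'] == [nq.TokenRange(3, 4, 1.0)]
    # Multi-term words add the penalty of each inner node they cross:
    #   words['12 ab'] == [nq.TokenRange(0, 2, 1.1)]   # 1.0 + 0.1 from node 1
    # No word extends past the BREAK_PHRASE on node 3, so e.g. '12 hallo'
    # never appears in the result.
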
index c39094f0e8ac03c78d79828681c62993d4c0194c..bfed38df57f5b5993dd0cdf804ac79288cd81d46 100644 (file)
@@ -46,3 +46,20 @@ def test_token_range_unimplemented_ops():
         nq.TokenRange(1, 3) <= nq.TokenRange(10, 12)
     with pytest.raises(TypeError):
         nq.TokenRange(1, 3) >= nq.TokenRange(10, 12)
+
+
+def test_query_extract_words():
+    q = nq.QueryStruct([])
+    q.add_node(nq.BREAK_WORD, nq.PHRASE_ANY, 0.1, '12', '')
+    q.add_node(nq.BREAK_TOKEN, nq.PHRASE_ANY, 0.0, 'ab', '')
+    q.add_node(nq.BREAK_PHRASE, nq.PHRASE_ANY, 0.0, '12', '')
+    q.add_node(nq.BREAK_END, nq.PHRASE_ANY, 0.5, 'hallo', '')
+
+    words = q.extract_words(base_penalty=1.0)
+
+    assert set(words.keys()) \
+             == {'12', 'ab', 'hallo', '12 ab', 'ab 12', '12 ab 12'}
+    assert sorted(words['12']) == [nq.TokenRange(0, 1, 1.0), nq.TokenRange(2, 3, 1.0)]
+    assert words['12 ab'] == [nq.TokenRange(0, 2, 1.1)]
+    assert words['hallo'] == [nq.TokenRange(3, 4, 1.0)]
+
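
Note on the expected key set: '12 ab 12' is produced even though node 3 carries BREAK_PHRASE, because a word may still end on the node that closes a phrase; extract_words only stops extending beyond that node. That is why no key such as '12 ab 12 hallo' appears and 'hallo' only ever occurs on its own.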