]> git.openstreetmap.org Git - nominatim.git/commitdiff
make word generation from query a class method
authorSarah Hoffmann <lonvia@denofr.de>
Wed, 26 Feb 2025 16:22:14 +0000 (17:22 +0100)
committerSarah Hoffmann <lonvia@denofr.de>
Tue, 4 Mar 2025 07:57:37 +0000 (08:57 +0100)
src/nominatim_api/search/icu_tokenizer.py
src/nominatim_api/search/query.py
test/python/api/search/test_query.py

index 60e712d59b1af8ce2c0daebc95ce683c948ccd35..e6bba95c6fd071d871ac1813265b734f3d6fdd13 100644 (file)
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
-from collections import defaultdict
 import dataclasses
 import difflib
 import re
@@ -47,29 +46,6 @@ PENALTY_IN_TOKEN_BREAK = {
 }
 
 
-WordDict = Dict[str, List[qmod.TokenRange]]
-
-
-def extract_words(query: qmod.QueryStruct, start: int,  words: WordDict) -> None:
-    """ Add all combinations of words in the terms list starting with
-        the term leading into node 'start'.
-
-        The words found will be added into the 'words' dictionary with
-        their start and end position.
-    """
-    nodes = query.nodes
-    total = len(nodes)
-    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
-    for first in range(start, total):
-        word = nodes[first].term_lookup
-        penalty = base_penalty
-        words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty))
-        for last in range(first + 1, min(first + 20, total)):
-            word = ' '.join((word, nodes[last].term_lookup))
-            penalty += nodes[last - 1].penalty
-            words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty))
-
-
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
@@ -203,8 +179,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         if not query.source:
             return query
 
-        words = self.split_query(query)
+        self.split_query(query)
         log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
+        words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
@@ -235,14 +212,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
-    def split_query(self, query: qmod.QueryStruct) -> WordDict:
+    def split_query(self, query: qmod.QueryStruct) -> None:
         """ Transliterate the phrases and split them into tokens.
-
-            Returns a dictionary of words for lookup together
-            with their position.
         """
-        phrase_start = 1
-        words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
@@ -263,13 +235,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                     query.nodes[-1].adjust_break(breakchar,
                                                  PENALTY_IN_TOKEN_BREAK[breakchar])
 
-            extract_words(query, phrase_start, words)
-
-            phrase_start = len(query.nodes)
         query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
 
-        return words
-
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
index fcd6763bf89bfb5057d0cbdf42ff01094348b897..07bb685b6d36ed5d145faf4e403aedce7668bb12 100644 (file)
@@ -7,8 +7,9 @@
 """
 Datastructures for a tokenized query.
 """
-from typing import List, Tuple, Optional, Iterator
+from typing import Dict, List, Tuple, Optional, Iterator
 from abc import ABC, abstractmethod
+from collections import defaultdict
 import dataclasses
 
 
@@ -320,3 +321,34 @@ class QueryStruct:
             For debugging purposes only.
         """
         return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
+
+    def extract_words(self, base_penalty: float = 0.0,
+                      start: int = 0,
+                      endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
+        """ Add all combinations of words that can be formed from the terms
+            between the given start and endnode. The terms are joined with
+            spaces for each break. Words can never go across a BREAK_PHRASE.
+
+            The functions returns a dictionary of possible words with their
+            position within the query and a penalty. The penalty is computed
+            from the base_penalty plus the penalty for each node the word
+            crosses.
+        """
+        if endpos is None:
+            endpos = len(self.nodes)
+
+        words: Dict[str, List[TokenRange]] = defaultdict(list)
+
+        for first in range(start, endpos - 1):
+            word = self.nodes[first + 1].term_lookup
+            penalty = base_penalty
+            words[word].append(TokenRange(first, first + 1, penalty=penalty))
+            if self.nodes[first + 1].btype != BREAK_PHRASE:
+                for last in range(first + 2, min(first + 20, endpos)):
+                    word = ' '.join((word, self.nodes[last].term_lookup))
+                    penalty += self.nodes[last - 1].penalty
+                    words[word].append(TokenRange(first, last, penalty=penalty))
+                    if self.nodes[last].btype == BREAK_PHRASE:
+                        break
+
+        return words
index c39094f0e8ac03c78d79828681c62993d4c0194c..bfed38df57f5b5993dd0cdf804ac79288cd81d46 100644 (file)
@@ -46,3 +46,20 @@ def test_token_range_unimplemented_ops():
         nq.TokenRange(1, 3) <= nq.TokenRange(10, 12)
     with pytest.raises(TypeError):
         nq.TokenRange(1, 3) >= nq.TokenRange(10, 12)
+
+
+def test_query_extract_words():
+    q = nq.QueryStruct([])
+    q.add_node(nq.BREAK_WORD, nq.PHRASE_ANY, 0.1, '12', '')
+    q.add_node(nq.BREAK_TOKEN, nq.PHRASE_ANY, 0.0, 'ab', '')
+    q.add_node(nq.BREAK_PHRASE, nq.PHRASE_ANY, 0.0, '12', '')
+    q.add_node(nq.BREAK_END, nq.PHRASE_ANY, 0.5, 'hallo', '')
+
+    words = q.extract_words(base_penalty=1.0)
+
+    assert set(words.keys()) \
+             == {'12', 'ab', 'hallo', '12 ab', 'ab 12', '12 ab 12'}
+    assert sorted(words['12']) == [nq.TokenRange(0, 1, 1.0), nq.TokenRange(2, 3, 1.0)]
+    assert words['12 ab'] == [nq.TokenRange(0, 2, 1.1)]
+    assert words['hallo'] == [nq.TokenRange(3, 4, 1.0)]
+