From: Sarah Hoffmann Date: Wed, 26 Feb 2025 16:22:14 +0000 (+0100) Subject: make word generation from query a class method X-Git-Url: https://git.openstreetmap.org./nominatim.git/commitdiff_plain/6759edfb5d4cf68856d04b0208d60b48448068b1 make word generation from query a class method --- diff --git a/src/nominatim_api/search/icu_tokenizer.py b/src/nominatim_api/search/icu_tokenizer.py index 60e712d5..e6bba95c 100644 --- a/src/nominatim_api/search/icu_tokenizer.py +++ b/src/nominatim_api/search/icu_tokenizer.py @@ -8,7 +8,6 @@ Implementation of query analysis for the ICU tokenizer. """ from typing import Tuple, Dict, List, Optional, Iterator, Any, cast -from collections import defaultdict import dataclasses import difflib import re @@ -47,29 +46,6 @@ PENALTY_IN_TOKEN_BREAK = { } -WordDict = Dict[str, List[qmod.TokenRange]] - - -def extract_words(query: qmod.QueryStruct, start: int, words: WordDict) -> None: - """ Add all combinations of words in the terms list starting with - the term leading into node 'start'. - - The words found will be added into the 'words' dictionary with - their start and end position. - """ - nodes = query.nodes - total = len(nodes) - base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD] - for first in range(start, total): - word = nodes[first].term_lookup - penalty = base_penalty - words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty)) - for last in range(first + 1, min(first + 20, total)): - word = ' '.join((word, nodes[last].term_lookup)) - penalty += nodes[last - 1].penalty - words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty)) - - @dataclasses.dataclass class ICUToken(qmod.Token): """ Specialised token for ICU tokenizer. 
@@ -203,8 +179,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if not query.source: return query - words = self.split_query(query) + self.split_query(query) log().var_dump('Transliterated query', lambda: query.get_transliterated_query()) + words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]) for row in await self.lookup_in_db(list(words.keys())): for trange in words[row.word_token]: @@ -235,14 +212,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): """ return cast(str, self.normalizer.transliterate(text)).strip('-: ') - def split_query(self, query: qmod.QueryStruct) -> WordDict: + def split_query(self, query: qmod.QueryStruct) -> None: """ Transliterate the phrases and split them into tokens. - - Returns a dictionary of words for lookup together - with their position. """ - phrase_start = 1 - words: WordDict = defaultdict(list) for phrase in query.source: query.nodes[-1].ptype = phrase.ptype phrase_split = re.split('([ :-])', phrase.text) @@ -263,13 +235,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): query.nodes[-1].adjust_break(breakchar, PENALTY_IN_TOKEN_BREAK[breakchar]) - extract_words(query, phrase_start, words) - - phrase_start = len(query.nodes) query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END]) - return words - async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]': """ Return the token information from the database for the given word tokens. diff --git a/src/nominatim_api/search/query.py b/src/nominatim_api/search/query.py index fcd6763b..07bb685b 100644 --- a/src/nominatim_api/search/query.py +++ b/src/nominatim_api/search/query.py @@ -7,8 +7,9 @@ """ Datastructures for a tokenized query. """ -from typing import List, Tuple, Optional, Iterator +from typing import Dict, List, Tuple, Optional, Iterator from abc import ABC, abstractmethod +from collections import defaultdict import dataclasses @@ -320,3 +321,34 @@ class QueryStruct: For debugging purposes only. 
""" return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes) + + def extract_words(self, base_penalty: float = 0.0, + start: int = 0, + endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]: + """ Collect all combinations of words that can be formed from the terms + between node 'start' and the node before 'endpos'. The terms are joined with + spaces for each break. Words can never go across a BREAK_PHRASE. + + The function returns a dictionary of possible words with their + position within the query and a penalty. The penalty is computed + from the base_penalty plus the penalty for each node the word + crosses. Words span at most 19 terms. + """ + if endpos is None: + endpos = len(self.nodes) + + words: Dict[str, List[TokenRange]] = defaultdict(list) + + for first in range(start, endpos - 1): + word = self.nodes[first + 1].term_lookup + penalty = base_penalty + words[word].append(TokenRange(first, first + 1, penalty=penalty)) + if self.nodes[first + 1].btype != BREAK_PHRASE: + for last in range(first + 2, min(first + 20, endpos)): + word = ' '.join((word, self.nodes[last].term_lookup)) + penalty += self.nodes[last - 1].penalty + words[word].append(TokenRange(first, last, penalty=penalty)) + if self.nodes[last].btype == BREAK_PHRASE: + break + + return words diff --git a/test/python/api/search/test_query.py b/test/python/api/search/test_query.py index c39094f0..bfed38df 100644 --- a/test/python/api/search/test_query.py +++ b/test/python/api/search/test_query.py @@ -46,3 +46,20 @@ def test_token_range_unimplemented_ops(): nq.TokenRange(1, 3) <= nq.TokenRange(10, 12) with pytest.raises(TypeError): nq.TokenRange(1, 3) >= nq.TokenRange(10, 12) + + +def test_query_extract_words(): + q = nq.QueryStruct([]) + q.add_node(nq.BREAK_WORD, nq.PHRASE_ANY, 0.1, '12', '') + q.add_node(nq.BREAK_TOKEN, nq.PHRASE_ANY, 0.0, 'ab', '') + q.add_node(nq.BREAK_PHRASE, nq.PHRASE_ANY, 0.0, '12', '') + q.add_node(nq.BREAK_END, nq.PHRASE_ANY, 0.5, 'hallo', '') + + words = q.extract_words(base_penalty=1.0) + + assert
set(words.keys()) \ + == {'12', 'ab', 'hallo', '12 ab', 'ab 12', '12 ab 12'} + assert sorted(words['12']) == [nq.TokenRange(0, 1, 1.0), nq.TokenRange(2, 3, 1.0)] + assert words['12 ab'] == [nq.TokenRange(0, 2, 1.1)] + assert words['hallo'] == [nq.TokenRange(3, 4, 1.0)] +