From e362a965e167dadd828a4a4b7fc58c6076e6586a Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 26 Feb 2025 14:37:08 +0100 Subject: [PATCH] search: merge QueryPart array with QueryNodes The basic information on terms is pretty much always used together with the node inforamtion. Merging them together saves some allocation while making lookup easier at the same time. --- src/nominatim_api/search/icu_tokenizer.py | 110 ++++++++---------- src/nominatim_api/search/query.py | 39 ++++++- .../api/search/test_api_search_query.py | 34 +++--- 3 files changed, 100 insertions(+), 83 deletions(-) diff --git a/src/nominatim_api/search/icu_tokenizer.py b/src/nominatim_api/search/icu_tokenizer.py index 1a449276..60e712d5 100644 --- a/src/nominatim_api/search/icu_tokenizer.py +++ b/src/nominatim_api/search/icu_tokenizer.py @@ -47,40 +47,27 @@ PENALTY_IN_TOKEN_BREAK = { } -@dataclasses.dataclass -class QueryPart: - """ Normalized and transliterated form of a single term in the query. - - When the term came out of a split during the transliteration, - the normalized string is the full word before transliteration. - Check the subsequent break type to figure out if the word is - continued. - - Penalty is the break penalty for the break following the token. - """ - token: str - normalized: str - penalty: float - - -QueryParts = List[QueryPart] WordDict = Dict[str, List[qmod.TokenRange]] -def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None: - """ Add all combinations of words in the terms list after the - given position to the word list. +def extract_words(query: qmod.QueryStruct, start: int, words: WordDict) -> None: + """ Add all combinations of words in the terms list starting with + the term leading into node 'start'. + + The words found will be added into the 'words' dictionary with + their start and end position. """ - total = len(terms) + nodes = query.nodes + total = len(nodes) base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD] for first in range(start, total): - word = terms[first].token + word = nodes[first].term_lookup penalty = base_penalty - words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty)) + words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty)) for last in range(first + 1, min(first + 20, total)): - word = ' '.join((word, terms[last].token)) - penalty += terms[last - 1].penalty - words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty)) + word = ' '.join((word, nodes[last].term_lookup)) + penalty += nodes[last - 1].penalty + words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty)) @dataclasses.dataclass @@ -216,8 +203,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if not query.source: return query - parts, words = self.split_query(query) - log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts)) + words = self.split_query(query) + log().var_dump('Transliterated query', lambda: query.get_transliterated_query()) for row in await self.lookup_in_db(list(words.keys())): for trange in words[row.word_token]: @@ -234,8 +221,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): else: query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token) - self.add_extra_tokens(query, parts) - self.rerank_tokens(query, parts) + self.add_extra_tokens(query) + self.rerank_tokens(query) log().table_dump('Word tokens', _dump_word_tokens(query)) @@ -248,15 +235,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): """ return cast(str, self.normalizer.transliterate(text)).strip('-: ') - def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]: + def split_query(self, query: qmod.QueryStruct) -> WordDict: """ Transliterate the phrases and split them into tokens. - Returns the list of transliterated tokens together with their - normalized form and a dictionary of words for lookup together + Returns a dictionary of words for lookup together with their position. """ - parts: QueryParts = [] - phrase_start = 0 + phrase_start = 1 words: WordDict = defaultdict(list) for phrase in query.source: query.nodes[-1].ptype = phrase.ptype @@ -272,18 +257,18 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if trans: for term in trans.split(' '): if term: - parts.append(QueryPart(term, word, - PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN])) - query.add_node(qmod.BREAK_TOKEN, phrase.ptype) - query.nodes[-1].btype = breakchar - parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar] + query.add_node(qmod.BREAK_TOKEN, phrase.ptype, + PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN], + term, word) + query.nodes[-1].adjust_break(breakchar, + PENALTY_IN_TOKEN_BREAK[breakchar]) - extract_words(parts, phrase_start, words) + extract_words(query, phrase_start, words) - phrase_start = len(parts) - query.nodes[-1].btype = qmod.BREAK_END + phrase_start = len(query.nodes) + query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END]) - return parts, words + return words async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]': """ Return the token information from the database for the @@ -292,18 +277,23 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): t = self.conn.t.meta.tables['word'] return await self.conn.execute(t.select().where(t.c.word_token.in_(words))) - def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: + def add_extra_tokens(self, query: qmod.QueryStruct) -> None: """ Add tokens to query that are not saved in the database. """ - for part, node, i in zip(parts, query.nodes, range(1000)): - if len(part.token) <= 4 and part.token.isdigit()\ - and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER): - query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER, + need_hnr = False + for i, node in enumerate(query.nodes): + is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART) + if need_hnr and is_full_token \ + and len(node.term_normalized) <= 4 and node.term_normalized.isdigit(): + query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER, ICUToken(penalty=0.5, token=0, - count=1, addr_count=1, lookup_word=part.token, - word_token=part.token, info=None)) + count=1, addr_count=1, + lookup_word=node.term_lookup, + word_token=node.term_lookup, info=None)) + + need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER) - def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: + def rerank_tokens(self, query: qmod.QueryStruct) -> None: """ Add penalties to tokens that depend on presence of other token. """ for i, node, tlist in query.iter_token_lists(): @@ -320,21 +310,15 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER: repl.add_penalty(0.5 - tlist.tokens[0].penalty) elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL): - norm = parts[i].normalized - for j in range(i + 1, tlist.end): - if node.btype != qmod.BREAK_TOKEN: - norm += ' ' + parts[j].normalized + norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1] + if n.btype != qmod.BREAK_TOKEN) + if not norm: + # Can happen when the token only covers a partial term + norm = query.nodes[i + 1].term_normalized for token in tlist.tokens: cast(ICUToken, token).rematch(norm) -def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str: - out = query.nodes[0].btype - for node, part in zip(query.nodes[1:], parts): - out += part.token + node.btype - return out - - def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]: yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info'] for node in query.nodes: diff --git a/src/nominatim_api/search/query.py b/src/nominatim_api/search/query.py index 8530c4f2..fcd6763b 100644 --- a/src/nominatim_api/search/query.py +++ b/src/nominatim_api/search/query.py @@ -171,11 +171,33 @@ class TokenList: @dataclasses.dataclass class QueryNode: """ A node of the query representing a break between terms. + + The node also contains information on the source term + ending at the node. The tokens are created from this information. """ btype: BreakType ptype: PhraseType + + penalty: float + """ Penalty for the break at this node. + """ + term_lookup: str + """ Transliterated term following this node. + """ + term_normalized: str + """ Normalised form of term following this node. + When the token resulted from a split during transliteration, + then this string contains the complete source term. + """ + starting: List[TokenList] = dataclasses.field(default_factory=list) + def adjust_break(self, btype: BreakType, penalty: float) -> None: + """ Change the break type and penalty for this node. + """ + self.btype = btype + self.penalty = penalty + def has_tokens(self, end: int, *ttypes: TokenType) -> bool: """ Check if there are tokens of the given types ending at the given node. @@ -218,19 +240,22 @@ class QueryStruct: def __init__(self, source: List[Phrase]) -> None: self.source = source self.nodes: List[QueryNode] = \ - [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)] + [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY, + 0.0, '', '')] def num_token_slots(self) -> int: """ Return the length of the query in vertice steps. """ return len(self.nodes) - 1 - def add_node(self, btype: BreakType, ptype: PhraseType) -> None: + def add_node(self, btype: BreakType, ptype: PhraseType, + break_penalty: float = 0.0, + term_lookup: str = '', term_normalized: str = '') -> None: """ Append a new break node with the given break type. The phrase type denotes the type for any tokens starting at the node. """ - self.nodes.append(QueryNode(btype, ptype)) + self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized)) def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None: """ Add a token to the query. 'start' and 'end' are the indexes of the @@ -287,3 +312,11 @@ class QueryStruct: if t.token == token: return f"[{tlist.ttype}]{t.lookup_word}" return 'None' + + def get_transliterated_query(self) -> str: + """ Return a string representation of the transliterated query + with the character representation of the different break types. + + For debugging purposes only. + """ + return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes) diff --git a/test/python/api/search/test_api_search_query.py b/test/python/api/search/test_api_search_query.py index 412a5bf2..08a1f7aa 100644 --- a/test/python/api/search/test_api_search_query.py +++ b/test/python/api/search/test_api_search_query.py @@ -21,6 +21,9 @@ def mktoken(tid: int): return MyToken(penalty=3.0, token=tid, count=1, addr_count=1, lookup_word='foo') +@pytest.fixture +def qnode(): + return query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY, 0.0 ,'', '') @pytest.mark.parametrize('ptype,ttype', [(query.PHRASE_ANY, 'W'), (query.PHRASE_AMENITY, 'Q'), @@ -37,27 +40,24 @@ def test_phrase_incompatible(ptype): assert not query._phrase_compatible_with(ptype, query.TOKEN_PARTIAL, True) -def test_query_node_empty(): - qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY) +def test_query_node_empty(qnode): + assert not qnode.has_tokens(3, query.TOKEN_PARTIAL) + assert qnode.get_tokens(3, query.TOKEN_WORD) is None - assert not qn.has_tokens(3, query.TOKEN_PARTIAL) - assert qn.get_tokens(3, query.TOKEN_WORD) is None +def test_query_node_with_content(qnode): + qnode.starting.append(query.TokenList(2, query.TOKEN_PARTIAL, [mktoken(100), mktoken(101)])) + qnode.starting.append(query.TokenList(2, query.TOKEN_WORD, [mktoken(1000)])) -def test_query_node_with_content(): - qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY) - qn.starting.append(query.TokenList(2, query.TOKEN_PARTIAL, [mktoken(100), mktoken(101)])) - qn.starting.append(query.TokenList(2, query.TOKEN_WORD, [mktoken(1000)])) + assert not qnode.has_tokens(3, query.TOKEN_PARTIAL) + assert not qnode.has_tokens(2, query.TOKEN_COUNTRY) + assert qnode.has_tokens(2, query.TOKEN_PARTIAL) + assert qnode.has_tokens(2, query.TOKEN_WORD) - assert not qn.has_tokens(3, query.TOKEN_PARTIAL) - assert not qn.has_tokens(2, query.TOKEN_COUNTRY) - assert qn.has_tokens(2, query.TOKEN_PARTIAL) - assert qn.has_tokens(2, query.TOKEN_WORD) - - assert qn.get_tokens(3, query.TOKEN_PARTIAL) is None - assert qn.get_tokens(2, query.TOKEN_COUNTRY) is None - assert len(qn.get_tokens(2, query.TOKEN_PARTIAL)) == 2 - assert len(qn.get_tokens(2, query.TOKEN_WORD)) == 1 + assert qnode.get_tokens(3, query.TOKEN_PARTIAL) is None + assert qnode.get_tokens(2, query.TOKEN_COUNTRY) is None + assert len(qnode.get_tokens(2, query.TOKEN_PARTIAL)) == 2 + assert len(qnode.get_tokens(2, query.TOKEN_WORD)) == 1 def test_query_struct_empty(): -- 2.39.5