X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/004883bdb1cfdfea053cb59fe32792c4e368e88c..f923304eead3cb9e9cfad8f41c33df1fdc1a16fd:/nominatim/api/search/icu_tokenizer.py?ds=sidebyside diff --git a/nominatim/api/search/icu_tokenizer.py b/nominatim/api/search/icu_tokenizer.py index 14698a28..eb90c122 100644 --- a/nominatim/api/search/icu_tokenizer.py +++ b/nominatim/api/search/icu_tokenizer.py @@ -8,7 +8,6 @@ Implementation of query analysis for the ICU tokenizer. """ from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast -from copy import copy from collections import defaultdict import dataclasses import difflib @@ -21,10 +20,8 @@ from nominatim.typing import SaRow from nominatim.api.connection import SearchConnection from nominatim.api.logging import log from nominatim.api.search import query as qmod - -# XXX: TODO -class AbstractQueryAnalyzer: - pass +from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer +from nominatim.db.sqlalchemy_types import Json DB_TO_TOKEN_TYPE = { @@ -86,7 +83,7 @@ class ICUToken(qmod.Token): seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm) distance = 0 for tag, afrom, ato, bfrom, bto in seq.get_opcodes(): - if tag == 'delete' and (afrom == 0 or ato == len(self.lookup_word)): + if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)): distance += 1 elif tag == 'replace': distance += max((ato-afrom), (bto-bfrom)) @@ -100,14 +97,21 @@ class ICUToken(qmod.Token): """ Create a ICUToken from the row of the word table. """ count = 1 if row.info is None else row.info.get('count', 1) + addr_count = 1 if row.info is None else row.info.get('addr_count', 1) penalty = 0.0 if row.type == 'w': penalty = 0.3 + elif row.type == 'W': + if len(row.word_token) == 1 and row.word_token == row.word: + penalty = 0.2 if row.word.isdigit() else 0.3 elif row.type == 'H': penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit()) if all(not c.isdigit() for c in row.word_token): penalty += 0.2 * (len(row.word_token) - 1) + elif row.type == 'C': + if len(row.word_token) == 1: + penalty = 0.3 if row.info is None: lookup_word = row.word @@ -118,9 +122,10 @@ class ICUToken(qmod.Token): else: lookup_word = row.word_token - return ICUToken(penalty=penalty, token=row.word_id, count=count, + return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count), lookup_word=lookup_word, is_indexed=True, - word_token=row.word_token, info=row.info) + word_token=row.word_token, info=row.info, + addr_count=max(1, addr_count)) @@ -136,10 +141,19 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): async def setup(self) -> None: """ Set up static data structures needed for the analysis. """ - rules = await self.conn.get_property('tokenizer_import_normalisation') - self.normalizer = Transliterator.createFromRules("normalization", rules) - rules = await self.conn.get_property('tokenizer_import_transliteration') - self.transliterator = Transliterator.createFromRules("transliteration", rules) + async def _make_normalizer() -> Any: + rules = await self.conn.get_property('tokenizer_import_normalisation') + return Transliterator.createFromRules("normalization", rules) + + self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer', + _make_normalizer) + + async def _make_transliterator() -> Any: + rules = await self.conn.get_property('tokenizer_import_transliteration') + return Transliterator.createFromRules("transliteration", rules) + + self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator', + _make_transliterator) if 'word' not in self.conn.t.meta.tables: sa.Table('word', self.conn.t.meta, @@ -147,7 +161,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): sa.Column('word_token', sa.Text, nullable=False), sa.Column('type', sa.Text, nullable=False), sa.Column('word', sa.Text), - sa.Column('info', self.conn.t.types.Json)) + sa.Column('info', Json)) async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct: @@ -156,7 +170,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): """ log().section('Analyze query (using ICU tokenizer)') normalized = list(filter(lambda p: p.text, - (qmod.Phrase(p.ptype, self.normalizer.transliterate(p.text)) + (qmod.Phrase(p.ptype, self.normalize_text(p.text)) for p in phrases))) query = qmod.QueryStruct(normalized) log().var_dump('Normalized query', query.source) @@ -172,13 +186,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if row.type == 'S': if row.info['op'] in ('in', 'near'): if trange.start == 0: - query.add_token(trange, qmod.TokenType.CATEGORY, token) + query.add_token(trange, qmod.TokenType.NEAR_ITEM, token) else: - query.add_token(trange, qmod.TokenType.QUALIFIER, token) - if trange.start == 0 or trange.end == query.num_token_slots(): - token = copy(token) - token.penalty += 0.1 * (query.num_token_slots()) - query.add_token(trange, qmod.TokenType.CATEGORY, token) + if trange.start == 0 and trange.end == query.num_token_slots(): + query.add_token(trange, qmod.TokenType.NEAR_ITEM, token) + else: + query.add_token(trange, qmod.TokenType.QUALIFIER, token) else: query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token) @@ -190,6 +203,14 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): return query + def normalize_text(self, text: str) -> str: + """ Bring the given text into a normalized form. That is the + standardized form search will work with. All information removed + at this stage is inevitably lost. + """ + return cast(str, self.normalizer.transliterate(text)) + + def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]: """ Transliterate the phrases and split them into tokens. @@ -238,7 +259,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if len(part.token) <= 4 and part[0].isdigit()\ and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER): query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER, - ICUToken(0.5, 0, 1, part.token, True, part.token, None)) + ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None)) def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: @@ -251,12 +272,11 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): and (repl.ttype != qmod.TokenType.HOUSENUMBER or len(tlist.tokens[0].lookup_word) > 4): repl.add_penalty(0.39) - elif tlist.ttype == qmod.TokenType.HOUSENUMBER: + elif tlist.ttype == qmod.TokenType.HOUSENUMBER \ + and len(tlist.tokens[0].lookup_word) <= 3: if any(c.isdigit() for c in tlist.tokens[0].lookup_word): for repl in node.starting: - if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER \ - and (repl.ttype != qmod.TokenType.HOUSENUMBER - or len(tlist.tokens[0].lookup_word) <= 3): + if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER: repl.add_penalty(0.5 - tlist.tokens[0].penalty) elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL): norm = parts[i].normalized