X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/96ce8f83bda30d1fc96756d36d7f47890f118575..4af0bc6056edf3596b07cfaf975bf195a51b8557:/src/nominatim_api/search/icu_tokenizer.py

diff --git a/src/nominatim_api/search/icu_tokenizer.py b/src/nominatim_api/search/icu_tokenizer.py
index 3b85f26d..ecc2c1c7 100644
--- a/src/nominatim_api/search/icu_tokenizer.py
+++ b/src/nominatim_api/search/icu_tokenizer.py
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
-from collections import defaultdict
 import dataclasses
 import difflib
 import re
@@ -25,7 +24,9 @@ from ..connection import SearchConnection
 from ..logging import log
 from . import query as qmod
 from ..query_preprocessing.config import QueryConfig
+from ..query_preprocessing.base import QueryProcessingFunc
 from .query_analyzer_factory import AbstractQueryAnalyzer
+from .postcode_parser import PostcodeParser
 
 
 DB_TO_TOKEN_TYPE = {
@@ -47,42 +48,6 @@ PENALTY_IN_TOKEN_BREAK = {
 }
 
 
-@dataclasses.dataclass
-class QueryPart:
-    """ Normalized and transliterated form of a single term in the query.
-
-        When the term came out of a split during the transliteration,
-        the normalized string is the full word before transliteration.
-        Check the subsequent break type to figure out if the word is
-        continued.
-
-        Penalty is the break penalty for the break following the token.
-    """
-    token: str
-    normalized: str
-    penalty: float
-
-
-QueryParts = List[QueryPart]
-WordDict = Dict[str, List[qmod.TokenRange]]
-
-
-def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
-    """ Add all combinations of words in the terms list after the
-        given position to the word list.
-    """
-    total = len(terms)
-    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
-    for first in range(start, total):
-        word = terms[first].token
-        penalty = base_penalty
-        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
-        for last in range(first + 1, min(first + 20, total)):
-            word = ' '.join((word, terms[last].token))
-            penalty += terms[last - 1].penalty
-            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
-
-
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
@@ -148,60 +113,51 @@ class ICUToken(qmod.Token):
                         addr_count=max(1, addr_count))
 
 
-class ICUQueryAnalyzer(AbstractQueryAnalyzer):
-    """ Converter for query strings into a tokenized query
-        using the tokens created by a ICU tokenizer.
-    """
-    def __init__(self, conn: SearchConnection) -> None:
-        self.conn = conn
-
-    async def setup(self) -> None:
-        """ Set up static data structures needed for the analysis.
-        """
-        async def _make_normalizer() -> Any:
-            rules = await self.conn.get_property('tokenizer_import_normalisation')
-            return Transliterator.createFromRules("normalization", rules)
-
-        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
-                                                           _make_normalizer)
-
-        async def _make_transliterator() -> Any:
-            rules = await self.conn.get_property('tokenizer_import_transliteration')
-            return Transliterator.createFromRules("transliteration", rules)
-
-        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
-                                                               _make_transliterator)
-
-        await self._setup_preprocessing()
-
-        if 'word' not in self.conn.t.meta.tables:
-            sa.Table('word', self.conn.t.meta,
-                     sa.Column('word_id', sa.Integer),
-                     sa.Column('word_token', sa.Text, nullable=False),
-                     sa.Column('type', sa.Text, nullable=False),
-                     sa.Column('word', sa.Text),
-                     sa.Column('info', Json))
+@dataclasses.dataclass
+class ICUAnalyzerConfig:
+    postcode_parser: PostcodeParser
+    normalizer: Transliterator
+    transliterator: Transliterator
+    preprocessors: List[QueryProcessingFunc]
 
-    async def _setup_preprocessing(self) -> None:
-        """ Load the rules for preprocessing and set up the handlers.
-        """
+    @staticmethod
+    async def create(conn: SearchConnection) -> 'ICUAnalyzerConfig':
+        rules = await conn.get_property('tokenizer_import_normalisation')
+        normalizer = Transliterator.createFromRules("normalization", rules)
 
-        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
-                                                        config='TOKENIZER_CONFIG')
-        preprocessing_rules = rules.get('query-preprocessing', [])
+        rules = await conn.get_property('tokenizer_import_transliteration')
+        transliterator = Transliterator.createFromRules("transliteration", rules)
 
-        self.preprocessors = []
+        preprocessing_rules = conn.config.load_sub_configuration('icu_tokenizer.yaml',
                                                                  config='TOKENIZER_CONFIG')\
                                          .get('query-preprocessing', [])
 
+        preprocessors: List[QueryProcessingFunc] = []
         for func in preprocessing_rules:
             if 'step' not in func:
                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
             if not isinstance(func['step'], str):
                 raise UsageError("'step' attribute must be a simple string.")
 
-            module = self.conn.config.load_plugin_module(
+            module = conn.config.load_plugin_module(
                 func['step'], 'nominatim_api.query_preprocessing')
-
-            self.preprocessors.append(
-                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
+            preprocessors.append(
+                module.create(QueryConfig(func).set_normalizer(normalizer)))
+
+        return ICUAnalyzerConfig(PostcodeParser(conn.config),
+                                 normalizer, transliterator, preprocessors)
+
+
+class ICUQueryAnalyzer(AbstractQueryAnalyzer):
+    """ Converter for query strings into a tokenized query
+        using the tokens created by a ICU tokenizer.
+    """
+    def __init__(self, conn: SearchConnection, config: ICUAnalyzerConfig) -> None:
+        self.conn = conn
+        self.postcode_parser = config.postcode_parser
+        self.normalizer = config.normalizer
+        self.transliterator = config.transliterator
+        self.preprocessors = config.preprocessors
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
@@ -222,8 +178,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         if not query.source:
             return query
 
-        parts, words = self.split_query(query)
-        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
+        self.split_query(query)
+        log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
+        words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
@@ -240,8 +197,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
-        self.add_extra_tokens(query, parts)
-        self.rerank_tokens(query, parts)
+        self.add_extra_tokens(query)
+        for start, end, pc in self.postcode_parser.parse(query):
+            query.add_token(qmod.TokenRange(start, end),
+                            qmod.TOKEN_POSTCODE,
+                            ICUToken(penalty=0.1, token=0, count=1, addr_count=1,
+                                     lookup_word=pc, word_token=pc, info=None))
+        self.rerank_tokens(query)
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
 
@@ -254,16 +216,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
-    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
+    def split_query(self, query: qmod.QueryStruct) -> None:
         """ Transliterate the phrases and split them into tokens.
-
-            Returns the list of transliterated tokens together with their
-            normalized form and a dictionary of words for lookup together
-            with their position.
         """
-        parts: QueryParts = []
-        phrase_start = 0
-        words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
@@ -278,38 +233,42 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if trans:
                     for term in trans.split(' '):
                         if term:
-                            parts.append(QueryPart(term, word,
-                                                   PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
-                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
-                    query.nodes[-1].btype = breakchar
-                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
+                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
+                                           PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
+                                           term, word)
+                    query.nodes[-1].adjust_break(breakchar,
+                                                 PENALTY_IN_TOKEN_BREAK[breakchar])
 
-            extract_words(parts, phrase_start, words)
-
-            phrase_start = len(parts)
-        query.nodes[-1].btype = qmod.BREAK_END
-
-        return parts, words
+        query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
 
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
+
+            This function excludes postcode tokens
         """
         t = self.conn.t.meta.tables['word']
-        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
+        return await self.conn.execute(t.select()
+                                       .where(t.c.word_token.in_(words))
+                                       .where(t.c.type != 'P'))
 
-    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def add_extra_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add tokens to query that are not saved in the database.
         """
-        for part, node, i in zip(parts, query.nodes, range(1000)):
-            if len(part.token) <= 4 and part.token.isdigit()\
-               and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
-                query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
+        need_hnr = False
+        for i, node in enumerate(query.nodes):
+            is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART)
+            if need_hnr and is_full_token \
+               and len(node.term_normalized) <= 4 and node.term_normalized.isdigit():
+                query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER,
                                 ICUToken(penalty=0.5, token=0,
-                                         count=1, addr_count=1, lookup_word=part.token,
-                                         word_token=part.token, info=None))
+                                         count=1, addr_count=1,
+                                         lookup_word=node.term_lookup,
+                                         word_token=node.term_lookup, info=None))
+
+            need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER)
 
-    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add penalties to tokens that depend on presence of other token.
         """
         for i, node, tlist in query.iter_token_lists():
@@ -326,28 +285,22 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                         if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
             elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
-                norm = parts[i].normalized
-                for j in range(i + 1, tlist.end):
-                    if node.btype != qmod.BREAK_TOKEN:
-                        norm += ' ' + parts[j].normalized
+                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
+                                if n.btype != qmod.BREAK_TOKEN)
+                if not norm:
+                    # Can happen when the token only covers a partial term
+                    norm = query.nodes[i + 1].term_normalized
                 for token in tlist.tokens:
                     cast(ICUToken, token).rematch(norm)
 
 
-def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
-    out = query.nodes[0].btype
-    for node, part in zip(query.nodes[1:], parts):
-        out += part.token + node.btype
-    return out
-
-
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
-    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
-    for node in query.nodes:
+    yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
+    for i, node in enumerate(query.nodes):
         for tlist in node.starting:
             for token in tlist.tokens:
                 t = cast(ICUToken, token)
-                yield [tlist.ttype, t.token, t.word_token or '',
+                yield [tlist.ttype, str(i), str(tlist.end), t.token, t.word_token or '',
                        t.lookup_word or '', t.penalty, t.count, t.info]
 
 
@@ -355,7 +308,17 @@ async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer
     """ Create and set up a new query analyzer for a database based
        on the ICU tokenizer.
    """
-    out = ICUQueryAnalyzer(conn)
-    await out.setup()
+    async def _get_config() -> ICUAnalyzerConfig:
+        if 'word' not in conn.t.meta.tables:
+            sa.Table('word', conn.t.meta,
+                     sa.Column('word_id', sa.Integer),
+                     sa.Column('word_token', sa.Text, nullable=False),
+                     sa.Column('type', sa.Text, nullable=False),
+                     sa.Column('word', sa.Text),
+                     sa.Column('info', Json))
+
+        return await ICUAnalyzerConfig.create(conn)
+
+    config = await conn.get_cached_value('ICUTOK', 'config', _get_config)
 
-    return out
+    return ICUQueryAnalyzer(conn, config)