]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/search/icu_tokenizer.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
index 3b85f26df7f1eea53323851e46d5479450d48342..ecc2c1c7f1c917ade41d5e7b2efba10b89caf96f 100644 (file)
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
-from collections import defaultdict
 import dataclasses
 import difflib
 import re
@@ -25,7 +24,9 @@ from ..connection import SearchConnection
 from ..logging import log
 from . import query as qmod
 from ..query_preprocessing.config import QueryConfig
+from ..query_preprocessing.base import QueryProcessingFunc
 from .query_analyzer_factory import AbstractQueryAnalyzer
+from .postcode_parser import PostcodeParser
 
 
 DB_TO_TOKEN_TYPE = {
@@ -47,42 +48,6 @@ PENALTY_IN_TOKEN_BREAK = {
 }
 
 
-@dataclasses.dataclass
-class QueryPart:
-    """ Normalized and transliterated form of a single term in the query.
-
-        When the term came out of a split during the transliteration,
-        the normalized string is the full word before transliteration.
-        Check the subsequent break type to figure out if the word is
-        continued.
-
-        Penalty is the break penalty for the break following the token.
-    """
-    token: str
-    normalized: str
-    penalty: float
-
-
-QueryParts = List[QueryPart]
-WordDict = Dict[str, List[qmod.TokenRange]]
-
-
-def extract_words(terms: List[QueryPart], start: int,  words: WordDict) -> None:
-    """ Add all combinations of words in the terms list after the
-        given position to the word list.
-    """
-    total = len(terms)
-    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
-    for first in range(start, total):
-        word = terms[first].token
-        penalty = base_penalty
-        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
-        for last in range(first + 1, min(first + 20, total)):
-            word = ' '.join((word, terms[last].token))
-            penalty += terms[last - 1].penalty
-            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
-
-
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
@@ -148,60 +113,51 @@ class ICUToken(qmod.Token):
                         addr_count=max(1, addr_count))
 
 
-class ICUQueryAnalyzer(AbstractQueryAnalyzer):
-    """ Converter for query strings into a tokenized query
-        using the tokens created by a ICU tokenizer.
-    """
-    def __init__(self, conn: SearchConnection) -> None:
-        self.conn = conn
-
-    async def setup(self) -> None:
-        """ Set up static data structures needed for the analysis.
-        """
-        async def _make_normalizer() -> Any:
-            rules = await self.conn.get_property('tokenizer_import_normalisation')
-            return Transliterator.createFromRules("normalization", rules)
-
-        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
-                                                           _make_normalizer)
-
-        async def _make_transliterator() -> Any:
-            rules = await self.conn.get_property('tokenizer_import_transliteration')
-            return Transliterator.createFromRules("transliteration", rules)
-
-        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
-                                                               _make_transliterator)
-
-        await self._setup_preprocessing()
-
-        if 'word' not in self.conn.t.meta.tables:
-            sa.Table('word', self.conn.t.meta,
-                     sa.Column('word_id', sa.Integer),
-                     sa.Column('word_token', sa.Text, nullable=False),
-                     sa.Column('type', sa.Text, nullable=False),
-                     sa.Column('word', sa.Text),
-                     sa.Column('info', Json))
+@dataclasses.dataclass
+class ICUAnalyzerConfig:
+    postcode_parser: PostcodeParser
+    normalizer: Transliterator
+    transliterator: Transliterator
+    preprocessors: List[QueryProcessingFunc]
 
-    async def _setup_preprocessing(self) -> None:
-        """ Load the rules for preprocessing and set up the handlers.
-        """
+    @staticmethod
+    async def create(conn: SearchConnection) -> 'ICUAnalyzerConfig':
+        rules = await conn.get_property('tokenizer_import_normalisation')
+        normalizer = Transliterator.createFromRules("normalization", rules)
 
-        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
-                                                        config='TOKENIZER_CONFIG')
-        preprocessing_rules = rules.get('query-preprocessing', [])
+        rules = await conn.get_property('tokenizer_import_transliteration')
+        transliterator = Transliterator.createFromRules("transliteration", rules)
 
-        self.preprocessors = []
+        preprocessing_rules = conn.config.load_sub_configuration('icu_tokenizer.yaml',
+                                                                 config='TOKENIZER_CONFIG')\
+                                         .get('query-preprocessing', [])
 
+        preprocessors: List[QueryProcessingFunc] = []
         for func in preprocessing_rules:
             if 'step' not in func:
                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
             if not isinstance(func['step'], str):
                 raise UsageError("'step' attribute must be a simple string.")
 
-            module = self.conn.config.load_plugin_module(
+            module = conn.config.load_plugin_module(
                         func['step'], 'nominatim_api.query_preprocessing')
-            self.preprocessors.append(
-                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
+            preprocessors.append(
+                module.create(QueryConfig(func).set_normalizer(normalizer)))
+
+        return ICUAnalyzerConfig(PostcodeParser(conn.config),
+                                 normalizer, transliterator, preprocessors)
+
+
+class ICUQueryAnalyzer(AbstractQueryAnalyzer):
+    """ Converter for query strings into a tokenized query
+        using the tokens created by a ICU tokenizer.
+    """
+    def __init__(self, conn: SearchConnection, config: ICUAnalyzerConfig) -> None:
+        self.conn = conn
+        self.postcode_parser = config.postcode_parser
+        self.normalizer = config.normalizer
+        self.transliterator = config.transliterator
+        self.preprocessors = config.preprocessors
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
@@ -222,8 +178,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         if not query.source:
             return query
 
-        parts, words = self.split_query(query)
-        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
+        self.split_query(query)
+        log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
+        words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
@@ -240,8 +197,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
-        self.add_extra_tokens(query, parts)
-        self.rerank_tokens(query, parts)
+        self.add_extra_tokens(query)
+        for start, end, pc in self.postcode_parser.parse(query):
+            query.add_token(qmod.TokenRange(start, end),
+                            qmod.TOKEN_POSTCODE,
+                            ICUToken(penalty=0.1, token=0, count=1, addr_count=1,
+                                     lookup_word=pc, word_token=pc, info=None))
+        self.rerank_tokens(query)
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
 
@@ -254,16 +216,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
-    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
+    def split_query(self, query: qmod.QueryStruct) -> None:
         """ Transliterate the phrases and split them into tokens.
-
-            Returns the list of transliterated tokens together with their
-            normalized form and a dictionary of words for lookup together
-            with their position.
         """
-        parts: QueryParts = []
-        phrase_start = 0
-        words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
@@ -278,38 +233,42 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if trans:
                     for term in trans.split(' '):
                         if term:
-                            parts.append(QueryPart(term, word,
-                                                   PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
-                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
-                    query.nodes[-1].btype = breakchar
-                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
+                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
+                                           PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
+                                           term, word)
+                    query.nodes[-1].adjust_break(breakchar,
+                                                 PENALTY_IN_TOKEN_BREAK[breakchar])
 
-            extract_words(parts, phrase_start, words)
-
-            phrase_start = len(parts)
-        query.nodes[-1].btype = qmod.BREAK_END
-
-        return parts, words
+        query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
 
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
+
+            This function excludes postcode tokens
         """
         t = self.conn.t.meta.tables['word']
-        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
+        return await self.conn.execute(t.select()
+                                        .where(t.c.word_token.in_(words))
+                                        .where(t.c.type != 'P'))
 
-    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def add_extra_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add tokens to query that are not saved in the database.
         """
-        for part, node, i in zip(parts, query.nodes, range(1000)):
-            if len(part.token) <= 4 and part.token.isdigit()\
-               and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
-                query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
+        need_hnr = False
+        for i, node in enumerate(query.nodes):
+            is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART)
+            if need_hnr and is_full_token \
+                    and len(node.term_normalized) <= 4 and node.term_normalized.isdigit():
+                query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER,
                                 ICUToken(penalty=0.5, token=0,
-                                         count=1, addr_count=1, lookup_word=part.token,
-                                         word_token=part.token, info=None))
+                                         count=1, addr_count=1,
+                                         lookup_word=node.term_lookup,
+                                         word_token=node.term_lookup, info=None))
+
+            need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER)
 
-    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add penalties to tokens that depend on presence of other token.
         """
         for i, node, tlist in query.iter_token_lists():
@@ -326,28 +285,22 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                         if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
             elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
-                norm = parts[i].normalized
-                for j in range(i + 1, tlist.end):
-                    if node.btype != qmod.BREAK_TOKEN:
-                        norm += '  ' + parts[j].normalized
+                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
+                                if n.btype != qmod.BREAK_TOKEN)
+                if not norm:
+                    # Can happen when the token only covers a partial term
+                    norm = query.nodes[i + 1].term_normalized
                 for token in tlist.tokens:
                     cast(ICUToken, token).rematch(norm)
 
 
-def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
-    out = query.nodes[0].btype
-    for node, part in zip(query.nodes[1:], parts):
-        out += part.token + node.btype
-    return out
-
-
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
-    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
-    for node in query.nodes:
+    yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
+    for i, node in enumerate(query.nodes):
         for tlist in node.starting:
             for token in tlist.tokens:
                 t = cast(ICUToken, token)
-                yield [tlist.ttype, t.token, t.word_token or '',
+                yield [tlist.ttype, str(i), str(tlist.end), t.token, t.word_token or '',
                        t.lookup_word or '', t.penalty, t.count, t.info]
 
 
@@ -355,7 +308,17 @@ async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer
     """ Create and set up a new query analyzer for a database based
         on the ICU tokenizer.
     """
-    out = ICUQueryAnalyzer(conn)
-    await out.setup()
+    async def _get_config() -> ICUAnalyzerConfig:
+        if 'word' not in conn.t.meta.tables:
+            sa.Table('word', conn.t.meta,
+                     sa.Column('word_id', sa.Integer),
+                     sa.Column('word_token', sa.Text, nullable=False),
+                     sa.Column('type', sa.Text, nullable=False),
+                     sa.Column('word', sa.Text),
+                     sa.Column('info', Json))
+
+        return await ICUAnalyzerConfig.create(conn)
+
+    config = await conn.get_cached_value('ICUTOK', 'config', _get_config)
 
-    return out
+    return ICUQueryAnalyzer(conn, config)