]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/search/icu_tokenizer.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
index 3b85f26df7f1eea53323851e46d5479450d48342..ecc2c1c7f1c917ade41d5e7b2efba10b89caf96f 100644 (file)
@@ -8,7 +8,6 @@
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
 Implementation of query analysis for the ICU tokenizer.
 """
 from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
-from collections import defaultdict
 import dataclasses
 import difflib
 import re
 import dataclasses
 import difflib
 import re
@@ -25,7 +24,9 @@ from ..connection import SearchConnection
 from ..logging import log
 from . import query as qmod
 from ..query_preprocessing.config import QueryConfig
 from ..logging import log
 from . import query as qmod
 from ..query_preprocessing.config import QueryConfig
+from ..query_preprocessing.base import QueryProcessingFunc
 from .query_analyzer_factory import AbstractQueryAnalyzer
 from .query_analyzer_factory import AbstractQueryAnalyzer
+from .postcode_parser import PostcodeParser
 
 
 DB_TO_TOKEN_TYPE = {
 
 
 DB_TO_TOKEN_TYPE = {
@@ -47,42 +48,6 @@ PENALTY_IN_TOKEN_BREAK = {
 }
 
 
 }
 
 
-@dataclasses.dataclass
-class QueryPart:
-    """ Normalized and transliterated form of a single term in the query.
-
-        When the term came out of a split during the transliteration,
-        the normalized string is the full word before transliteration.
-        Check the subsequent break type to figure out if the word is
-        continued.
-
-        Penalty is the break penalty for the break following the token.
-    """
-    token: str
-    normalized: str
-    penalty: float
-
-
-QueryParts = List[QueryPart]
-WordDict = Dict[str, List[qmod.TokenRange]]
-
-
-def extract_words(terms: List[QueryPart], start: int,  words: WordDict) -> None:
-    """ Add all combinations of words in the terms list after the
-        given position to the word list.
-    """
-    total = len(terms)
-    base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
-    for first in range(start, total):
-        word = terms[first].token
-        penalty = base_penalty
-        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
-        for last in range(first + 1, min(first + 20, total)):
-            word = ' '.join((word, terms[last].token))
-            penalty += terms[last - 1].penalty
-            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
-
-
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
@@ -148,60 +113,51 @@ class ICUToken(qmod.Token):
                         addr_count=max(1, addr_count))
 
 
                         addr_count=max(1, addr_count))
 
 
-class ICUQueryAnalyzer(AbstractQueryAnalyzer):
-    """ Converter for query strings into a tokenized query
-        using the tokens created by a ICU tokenizer.
-    """
-    def __init__(self, conn: SearchConnection) -> None:
-        self.conn = conn
-
-    async def setup(self) -> None:
-        """ Set up static data structures needed for the analysis.
-        """
-        async def _make_normalizer() -> Any:
-            rules = await self.conn.get_property('tokenizer_import_normalisation')
-            return Transliterator.createFromRules("normalization", rules)
-
-        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
-                                                           _make_normalizer)
-
-        async def _make_transliterator() -> Any:
-            rules = await self.conn.get_property('tokenizer_import_transliteration')
-            return Transliterator.createFromRules("transliteration", rules)
-
-        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
-                                                               _make_transliterator)
-
-        await self._setup_preprocessing()
-
-        if 'word' not in self.conn.t.meta.tables:
-            sa.Table('word', self.conn.t.meta,
-                     sa.Column('word_id', sa.Integer),
-                     sa.Column('word_token', sa.Text, nullable=False),
-                     sa.Column('type', sa.Text, nullable=False),
-                     sa.Column('word', sa.Text),
-                     sa.Column('info', Json))
+@dataclasses.dataclass
+class ICUAnalyzerConfig:
+    postcode_parser: PostcodeParser
+    normalizer: Transliterator
+    transliterator: Transliterator
+    preprocessors: List[QueryProcessingFunc]
 
 
-    async def _setup_preprocessing(self) -> None:
-        """ Load the rules for preprocessing and set up the handlers.
-        """
+    @staticmethod
+    async def create(conn: SearchConnection) -> 'ICUAnalyzerConfig':
+        rules = await conn.get_property('tokenizer_import_normalisation')
+        normalizer = Transliterator.createFromRules("normalization", rules)
 
 
-        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
-                                                        config='TOKENIZER_CONFIG')
-        preprocessing_rules = rules.get('query-preprocessing', [])
+        rules = await conn.get_property('tokenizer_import_transliteration')
+        transliterator = Transliterator.createFromRules("transliteration", rules)
 
 
-        self.preprocessors = []
+        preprocessing_rules = conn.config.load_sub_configuration('icu_tokenizer.yaml',
+                                                                 config='TOKENIZER_CONFIG')\
+                                         .get('query-preprocessing', [])
 
 
+        preprocessors: List[QueryProcessingFunc] = []
         for func in preprocessing_rules:
             if 'step' not in func:
                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
             if not isinstance(func['step'], str):
                 raise UsageError("'step' attribute must be a simple string.")
 
         for func in preprocessing_rules:
             if 'step' not in func:
                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
             if not isinstance(func['step'], str):
                 raise UsageError("'step' attribute must be a simple string.")
 
-            module = self.conn.config.load_plugin_module(
+            module = conn.config.load_plugin_module(
                         func['step'], 'nominatim_api.query_preprocessing')
                         func['step'], 'nominatim_api.query_preprocessing')
-            self.preprocessors.append(
-                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
+            preprocessors.append(
+                module.create(QueryConfig(func).set_normalizer(normalizer)))
+
+        return ICUAnalyzerConfig(PostcodeParser(conn.config),
+                                 normalizer, transliterator, preprocessors)
+
+
+class ICUQueryAnalyzer(AbstractQueryAnalyzer):
+    """ Converter for query strings into a tokenized query
+        using the tokens created by a ICU tokenizer.
+    """
+    def __init__(self, conn: SearchConnection, config: ICUAnalyzerConfig) -> None:
+        self.conn = conn
+        self.postcode_parser = config.postcode_parser
+        self.normalizer = config.normalizer
+        self.transliterator = config.transliterator
+        self.preprocessors = config.preprocessors
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
@@ -222,8 +178,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         if not query.source:
             return query
 
         if not query.source:
             return query
 
-        parts, words = self.split_query(query)
-        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
+        self.split_query(query)
+        log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
+        words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
@@ -240,8 +197,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
-        self.add_extra_tokens(query, parts)
-        self.rerank_tokens(query, parts)
+        self.add_extra_tokens(query)
+        for start, end, pc in self.postcode_parser.parse(query):
+            query.add_token(qmod.TokenRange(start, end),
+                            qmod.TOKEN_POSTCODE,
+                            ICUToken(penalty=0.1, token=0, count=1, addr_count=1,
+                                     lookup_word=pc, word_token=pc, info=None))
+        self.rerank_tokens(query)
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
 
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
 
@@ -254,16 +216,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
-    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
+    def split_query(self, query: qmod.QueryStruct) -> None:
         """ Transliterate the phrases and split them into tokens.
         """ Transliterate the phrases and split them into tokens.
-
-            Returns the list of transliterated tokens together with their
-            normalized form and a dictionary of words for lookup together
-            with their position.
         """
         """
-        parts: QueryParts = []
-        phrase_start = 0
-        words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
             phrase_split = re.split('([ :-])', phrase.text)
@@ -278,38 +233,42 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if trans:
                     for term in trans.split(' '):
                         if term:
                 if trans:
                     for term in trans.split(' '):
                         if term:
-                            parts.append(QueryPart(term, word,
-                                                   PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
-                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
-                    query.nodes[-1].btype = breakchar
-                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
+                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
+                                           PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
+                                           term, word)
+                    query.nodes[-1].adjust_break(breakchar,
+                                                 PENALTY_IN_TOKEN_BREAK[breakchar])
 
 
-            extract_words(parts, phrase_start, words)
-
-            phrase_start = len(parts)
-        query.nodes[-1].btype = qmod.BREAK_END
-
-        return parts, words
+        query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
 
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
 
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
+
+            This function excludes postcode tokens
         """
         t = self.conn.t.meta.tables['word']
         """
         t = self.conn.t.meta.tables['word']
-        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
+        return await self.conn.execute(t.select()
+                                        .where(t.c.word_token.in_(words))
+                                        .where(t.c.type != 'P'))
 
 
-    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def add_extra_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add tokens to query that are not saved in the database.
         """
         """ Add tokens to query that are not saved in the database.
         """
-        for part, node, i in zip(parts, query.nodes, range(1000)):
-            if len(part.token) <= 4 and part.token.isdigit()\
-               and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
-                query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
+        need_hnr = False
+        for i, node in enumerate(query.nodes):
+            is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART)
+            if need_hnr and is_full_token \
+                    and len(node.term_normalized) <= 4 and node.term_normalized.isdigit():
+                query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER,
                                 ICUToken(penalty=0.5, token=0,
                                 ICUToken(penalty=0.5, token=0,
-                                         count=1, addr_count=1, lookup_word=part.token,
-                                         word_token=part.token, info=None))
+                                         count=1, addr_count=1,
+                                         lookup_word=node.term_lookup,
+                                         word_token=node.term_lookup, info=None))
+
+            need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER)
 
 
-    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add penalties to tokens that depend on presence of other token.
         """
         for i, node, tlist in query.iter_token_lists():
         """ Add penalties to tokens that depend on presence of other token.
         """
         for i, node, tlist in query.iter_token_lists():
@@ -326,28 +285,22 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                         if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
             elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
                         if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
             elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
-                norm = parts[i].normalized
-                for j in range(i + 1, tlist.end):
-                    if node.btype != qmod.BREAK_TOKEN:
-                        norm += '  ' + parts[j].normalized
+                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
+                                if n.btype != qmod.BREAK_TOKEN)
+                if not norm:
+                    # Can happen when the token only covers a partial term
+                    norm = query.nodes[i + 1].term_normalized
                 for token in tlist.tokens:
                     cast(ICUToken, token).rematch(norm)
 
 
                 for token in tlist.tokens:
                     cast(ICUToken, token).rematch(norm)
 
 
-def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
-    out = query.nodes[0].btype
-    for node, part in zip(query.nodes[1:], parts):
-        out += part.token + node.btype
-    return out
-
-
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
-    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
-    for node in query.nodes:
+    yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
+    for i, node in enumerate(query.nodes):
         for tlist in node.starting:
             for token in tlist.tokens:
                 t = cast(ICUToken, token)
         for tlist in node.starting:
             for token in tlist.tokens:
                 t = cast(ICUToken, token)
-                yield [tlist.ttype, t.token, t.word_token or '',
+                yield [tlist.ttype, str(i), str(tlist.end), t.token, t.word_token or '',
                        t.lookup_word or '', t.penalty, t.count, t.info]
 
 
                        t.lookup_word or '', t.penalty, t.count, t.info]
 
 
@@ -355,7 +308,17 @@ async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer
     """ Create and set up a new query analyzer for a database based
         on the ICU tokenizer.
     """
     """ Create and set up a new query analyzer for a database based
         on the ICU tokenizer.
     """
-    out = ICUQueryAnalyzer(conn)
-    await out.setup()
+    async def _get_config() -> ICUAnalyzerConfig:
+        if 'word' not in conn.t.meta.tables:
+            sa.Table('word', conn.t.meta,
+                     sa.Column('word_id', sa.Integer),
+                     sa.Column('word_token', sa.Text, nullable=False),
+                     sa.Column('type', sa.Text, nullable=False),
+                     sa.Column('word', sa.Text),
+                     sa.Column('info', Json))
+
+        return await ICUAnalyzerConfig.create(conn)
+
+    config = await conn.get_cached_value('ICUTOK', 'config', _get_config)
 
 
-    return out
+    return ICUQueryAnalyzer(conn, config)