]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/search/icu_tokenizer.py
add SOFT_PHRASE break and enable parsing
[nominatim.git] / src / nominatim_api / search / icu_tokenizer.py
index 971e95beec1a6935b58e7c9cc4879d9797a73f1b..d52614fdaf4863b0aaa2e2721659988d9c03470d 100644 (file)
@@ -11,17 +11,21 @@ from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
 from collections import defaultdict
 import dataclasses
 import difflib
 from collections import defaultdict
 import dataclasses
 import difflib
+import re
+from itertools import zip_longest
 
 from icu import Transliterator
 
 import sqlalchemy as sa
 
 
 from icu import Transliterator
 
 import sqlalchemy as sa
 
+from ..errors import UsageError
 from ..typing import SaRow
 from ..sql.sqlalchemy_types import Json
 from ..connection import SearchConnection
 from ..logging import log
 from ..typing import SaRow
 from ..sql.sqlalchemy_types import Json
 from ..connection import SearchConnection
 from ..logging import log
-from ..search import query as qmod
-from ..search.query_analyzer_factory import AbstractQueryAnalyzer
+from . import query as qmod
+from ..query_preprocessing.config import QueryConfig
+from .query_analyzer_factory import AbstractQueryAnalyzer
 
 
 DB_TO_TOKEN_TYPE = {
 
 
 DB_TO_TOKEN_TYPE = {
@@ -48,6 +52,7 @@ class QueryPart(NamedTuple):
 QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
 QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
+
 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
     """ Return all combinations of words in the terms list after the
         given position.
 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
     """ Return all combinations of words in the terms list after the
         given position.
@@ -72,7 +77,6 @@ class ICUToken(qmod.Token):
         assert self.info
         return self.info.get('class', ''), self.info.get('type', '')
 
         assert self.info
         return self.info.get('class', ''), self.info.get('type', '')
 
-
     def rematch(self, norm: str) -> None:
         """ Check how well the token matches the given normalized string
             and add a penalty, if necessary.
     def rematch(self, norm: str) -> None:
         """ Check how well the token matches the given normalized string
             and add a penalty, if necessary.
@@ -91,7 +95,6 @@ class ICUToken(qmod.Token):
                 distance += abs((ato-afrom) - (bto-bfrom))
         self.penalty += (distance/len(self.lookup_word))
 
                 distance += abs((ato-afrom) - (bto-bfrom))
         self.penalty += (distance/len(self.lookup_word))
 
-
     @staticmethod
     def from_db_row(row: SaRow) -> 'ICUToken':
         """ Create a ICUToken from the row of the word table.
     @staticmethod
     def from_db_row(row: SaRow) -> 'ICUToken':
         """ Create a ICUToken from the row of the word table.
@@ -123,21 +126,18 @@ class ICUToken(qmod.Token):
             lookup_word = row.word_token
 
         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
             lookup_word = row.word_token
 
         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
-                        lookup_word=lookup_word, is_indexed=True,
+                        lookup_word=lookup_word,
                         word_token=row.word_token, info=row.info,
                         addr_count=max(1, addr_count))
 
 
                         word_token=row.word_token, info=row.info,
                         addr_count=max(1, addr_count))
 
 
-
 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     """ Converter for query strings into a tokenized query
         using the tokens created by a ICU tokenizer.
     """
 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     """ Converter for query strings into a tokenized query
         using the tokens created by a ICU tokenizer.
     """
-
     def __init__(self, conn: SearchConnection) -> None:
         self.conn = conn
 
     def __init__(self, conn: SearchConnection) -> None:
         self.conn = conn
 
-
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
     async def setup(self) -> None:
         """ Set up static data structures needed for the analysis.
         """
@@ -155,6 +155,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                                _make_transliterator)
 
         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                                _make_transliterator)
 
+        await self._setup_preprocessing()
+
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
                      sa.Column('word_id', sa.Integer),
         if 'word' not in self.conn.t.meta.tables:
             sa.Table('word', self.conn.t.meta,
                      sa.Column('word_id', sa.Integer),
@@ -163,16 +165,36 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                      sa.Column('word', sa.Text),
                      sa.Column('info', Json))
 
                      sa.Column('word', sa.Text),
                      sa.Column('info', Json))
 
+    async def _setup_preprocessing(self) -> None:
+        """ Load the rules for preprocessing and set up the handlers.
+        """
+
+        rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
+                                                        config='TOKENIZER_CONFIG')
+        preprocessing_rules = rules.get('query-preprocessing', [])
+
+        self.preprocessors = []
+
+        for func in preprocessing_rules:
+            if 'step' not in func:
+                raise UsageError("Preprocessing rule is missing the 'step' attribute.")
+            if not isinstance(func['step'], str):
+                raise UsageError("'step' attribute must be a simple string.")
+
+            module = self.conn.config.load_plugin_module(
+                        func['step'], 'nominatim_api.query_preprocessing')
+            self.preprocessors.append(
+                module.create(QueryConfig(func).set_normalizer(self.normalizer)))
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
             tokenized query.
         """
         log().section('Analyze query (using ICU tokenizer)')
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
         """ Analyze the given list of phrases and return the
             tokenized query.
         """
         log().section('Analyze query (using ICU tokenizer)')
-        normalized = list(filter(lambda p: p.text,
-                                 (qmod.Phrase(p.ptype, self.normalize_text(p.text))
-                                  for p in phrases)))
-        query = qmod.QueryStruct(normalized)
+        for func in self.preprocessors:
+            phrases = func(phrases)
+        query = qmod.QueryStruct(phrases)
+
         log().var_dump('Normalized query', query.source)
         if not query.source:
             return query
         log().var_dump('Normalized query', query.source)
         if not query.source:
             return query
@@ -202,7 +224,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 
         return query
 
 
         return query
 
-
     def normalize_text(self, text: str) -> str:
         """ Bring the given text into a normalized form. That is the
             standardized form search will work with. All information removed
     def normalize_text(self, text: str) -> str:
         """ Bring the given text into a normalized form. That is the
             standardized form search will work with. All information removed
@@ -210,7 +231,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text))
 
         """
         return cast(str, self.normalizer.transliterate(text))
 
-
     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
         """ Transliterate the phrases and split them into tokens.
 
     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
         """ Transliterate the phrases and split them into tokens.
 
@@ -224,16 +244,22 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         wordnr = 0
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
         wordnr = 0
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
-            for word in phrase.text.split(' '):
+            phrase_split = re.split('([ :-])', phrase.text)
+            # The zip construct will give us the pairs of word/break from
+            # the regular expression split. As the split array ends on the
+            # final word, we simply use the fillvalue to even out the list and
+            # add the phrase break at the end.
+            for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
+                if not word:
+                    continue
                 trans = self.transliterator.transliterate(word)
                 if trans:
                     for term in trans.split(' '):
                         if term:
                             parts.append(QueryPart(term, word, wordnr))
                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                 trans = self.transliterator.transliterate(word)
                 if trans:
                     for term in trans.split(' '):
                         if term:
                             parts.append(QueryPart(term, word, wordnr))
                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
-                    query.nodes[-1].btype = qmod.BreakType.WORD
+                    query.nodes[-1].btype = qmod.BreakType(breakchar)
                 wordnr += 1
                 wordnr += 1
-            query.nodes[-1].btype = qmod.BreakType.PHRASE
 
             for word, wrange in yield_words(parts, phrase_start):
                 words[word].append(wrange)
 
             for word, wrange in yield_words(parts, phrase_start):
                 words[word].append(wrange)
@@ -243,7 +269,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 
         return parts, words
 
 
         return parts, words
 
-
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
             given word tokens.
@@ -251,7 +276,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         t = self.conn.t.meta.tables['word']
         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
 
         t = self.conn.t.meta.tables['word']
         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
 
-
     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add tokens to query that are not saved in the database.
         """
     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add tokens to query that are not saved in the database.
         """
@@ -259,8 +283,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None))
-
+                                ICUToken(penalty=0.5, token=0,
+                                         count=1, addr_count=1, lookup_word=part.token,
+                                         word_token=part.token, info=None))
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add penalties to tokens that depend on presence of other token.
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
         """ Add penalties to tokens that depend on presence of other token.
@@ -272,8 +297,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
                             or len(tlist.tokens[0].lookup_word) > 4):
                         repl.add_penalty(0.39)
                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
                             or len(tlist.tokens[0].lookup_word) > 4):
                         repl.add_penalty(0.39)
-            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
-                 and len(tlist.tokens[0].lookup_word) <= 3:
+            elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
+                  and len(tlist.tokens[0].lookup_word) <= 3):
                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                     for repl in node.starting:
                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                     for repl in node.starting:
                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER: