"""
Implementation of query analysis for the ICU tokenizer.
"""
-from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
+from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
+import re
+from itertools import zip_longest
from icu import Transliterator
import sqlalchemy as sa
+from ..errors import UsageError
from ..typing import SaRow
from ..sql.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
-from ..search import query as qmod
-from ..search.query_analyzer_factory import AbstractQueryAnalyzer
+from . import query as qmod
+from ..query_preprocessing.config import QueryConfig
+from .query_analyzer_factory import AbstractQueryAnalyzer
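+# Mapping of the type column in the word table to the token types
+# of the query module.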
DB_TO_TOKEN_TYPE = {
- 'W': qmod.TokenType.WORD,
- 'w': qmod.TokenType.PARTIAL,
- 'H': qmod.TokenType.HOUSENUMBER,
- 'P': qmod.TokenType.POSTCODE,
- 'C': qmod.TokenType.COUNTRY
+ 'W': qmod.TOKEN_WORD,
+ 'w': qmod.TOKEN_PARTIAL,
+ 'H': qmod.TOKEN_HOUSENUMBER,
+ 'P': qmod.TOKEN_POSTCODE,
+ 'C': qmod.TOKEN_COUNTRY
+}
+
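+# Penalty added to a token for each break of the given type that it spans.
+# Breaks introduced by transliteration and hyphens are free, plain word
+# breaks are cheap, while crossing a phrase boundary is expensive.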
+PENALTY_IN_TOKEN_BREAK = {
+ qmod.BREAK_START: 0.5,
+ qmod.BREAK_END: 0.5,
+ qmod.BREAK_PHRASE: 0.5,
+ qmod.BREAK_SOFT_PHRASE: 0.5,
+ qmod.BREAK_WORD: 0.1,
+ qmod.BREAK_PART: 0.0,
+ qmod.BREAK_TOKEN: 0.0
}
-class QueryPart(NamedTuple):
+@dataclasses.dataclass
+class QueryPart:
""" Normalized and transliterated form of a single term in the query.
+
When the term came out of a split during the transliteration,
the normalized string is the full word before transliteration.
- The word number keeps track of the word before transliteration
- and can be used to identify partial transliterated terms.
+ Check the subsequent break type to figure out if the word is
+ continued.
+
+ Penalty is the break penalty for the break following the token.
"""
token: str
normalized: str
- word_number: int
+ penalty: float
QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]
-def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
- """ Return all combinations of words in the terms list after the
- given position.
+def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
+ """ Add all combinations of words in the terms list after the
+ given position to the word list.
"""
total = len(terms)
+ base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
for first in range(start, total):
word = terms[first].token
- yield word, qmod.TokenRange(first, first + 1)
+ penalty = base_penalty
+ words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
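+ # Add multi-word combinations of up to 20 consecutive terms. Each
+ # additional term adds the penalty of the break that is crossed.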
for last in range(first + 1, min(first + 20, total)):
word = ' '.join((word, terms[last].token))
- yield word, qmod.TokenRange(first, last + 1)
+ penalty += terms[last - 1].penalty
+ words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
@dataclasses.dataclass
self.penalty += (distance/len(self.lookup_word))
@staticmethod
- def from_db_row(row: SaRow) -> 'ICUToken':
+ def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
""" Create a ICUToken from the row of the word table.
"""
count = 1 if row.info is None else row.info.get('count', 1)
addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
- penalty = 0.0
+ penalty = base_penalty
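+ # Type-specific penalties come on top of the accumulated break
+ # penalty of the token range this token is created for.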
if row.type == 'w':
- penalty = 0.3
+ penalty += 0.3
elif row.type == 'W':
if len(row.word_token) == 1 and row.word_token == row.word:
- penalty = 0.2 if row.word.isdigit() else 0.3
+ penalty += 0.2 if row.word.isdigit() else 0.3
elif row.type == 'H':
- penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
+ penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
if all(not c.isdigit() for c in row.word_token):
penalty += 0.2 * (len(row.word_token) - 1)
elif row.type == 'C':
if len(row.word_token) == 1:
- penalty = 0.3
+ penalty += 0.3
if row.info is None:
lookup_word = row.word
self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
_make_transliterator)
+ await self._setup_preprocessing()
+
if 'word' not in self.conn.t.meta.tables:
sa.Table('word', self.conn.t.meta,
sa.Column('word_id', sa.Integer),
sa.Column('word', sa.Text),
sa.Column('info', Json))
+ async def _setup_preprocessing(self) -> None:
+ """ Load the rules for preprocessing and set up the handlers.
+ """
+
+ rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
+ config='TOKENIZER_CONFIG')
+ preprocessing_rules = rules.get('query-preprocessing', [])
+
+ self.preprocessors = []
+
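+ # Every rule must be a mapping with at least a 'step' key naming the
+ # preprocessing module to load from nominatim_api.query_preprocessing.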
+ for func in preprocessing_rules:
+ if 'step' not in func:
+ raise UsageError("Preprocessing rule is missing the 'step' attribute.")
+ if not isinstance(func['step'], str):
+ raise UsageError("'step' attribute must be a simple string.")
+
+ module = self.conn.config.load_plugin_module(
+ func['step'], 'nominatim_api.query_preprocessing')
+ self.preprocessors.append(
+ module.create(QueryConfig(func).set_normalizer(self.normalizer)))
+
async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
""" Analyze the given list of phrases and return the
tokenized query.
"""
log().section('Analyze query (using ICU tokenizer)')
- normalized = list(filter(lambda p: p.text,
- (qmod.Phrase(p.ptype, self.normalize_text(p.text))
- for p in phrases)))
- query = qmod.QueryStruct(normalized)
+ for func in self.preprocessors:
+ phrases = func(phrases)
+
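+ # Weed out queries that consist of a single phrase with five or more
+ # terms of at most two characters each. Such input is almost certainly
+ # not an address and would only trigger expensive partial-word lookups.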
+ if len(phrases) == 1 \
+ and phrases[0].text.count(' ') > 3 \
+ and max(len(s) for s in phrases[0].text.split()) < 3:
+ phrases = []
+
+ query = qmod.QueryStruct(phrases)
+
log().var_dump('Normalized query', query.source)
if not query.source:
return query
for row in await self.lookup_in_db(list(words.keys())):
for trange in words[row.word_token]:
- token = ICUToken.from_db_row(row)
+ token = ICUToken.from_db_row(row, trange.penalty or 0.0)
if row.type == 'S':
if row.info['op'] in ('in', 'near'):
if trange.start == 0:
- query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
+ query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
else:
if trange.start == 0 and trange.end == query.num_token_slots():
- query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
+ query.add_token(trange, qmod.TOKEN_NEAR_ITEM, token)
else:
- query.add_token(trange, qmod.TokenType.QUALIFIER, token)
+ query.add_token(trange, qmod.TOKEN_QUALIFIER, token)
else:
query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
standardized form search will work with. All information removed
at this stage is inevitably lost.
"""
- norm = cast(str, self.normalizer.transliterate(text))
- numspaces = norm.count(' ')
- if numspaces > 4 and len(norm) <= (numspaces + 1) * 3:
- return ''
-
- return norm
+ return cast(str, self.normalizer.transliterate(text)).strip('-: ')
def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
""" Transliterate the phrases and split them into tokens.
"""
parts: QueryParts = []
phrase_start = 0
- words = defaultdict(list)
- wordnr = 0
+ words: WordDict = defaultdict(list)
for phrase in query.source:
query.nodes[-1].ptype = phrase.ptype
- for word in phrase.text.split(' '):
+ phrase_split = re.split('([ :-])', phrase.text)
+ # The zip construct will give us the pairs of word/break from
+ # the regular expression split. As the split array ends on the
+ # final word, we simply use the fillvalue to even out the list and
+ # add the phrase break at the end.
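+ # Example: 'rue:de-paris' is split into
+ # ['rue', ':', 'de', '-', 'paris'] and paired up as
+ # ('rue', ':'), ('de', '-'), ('paris', ',').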
+ for word, breakchar in zip_longest(*[iter(phrase_split)]*2, fillvalue=','):
+ if not word:
+ continue
trans = self.transliterator.transliterate(word)
if trans:
for term in trans.split(' '):
if term:
- parts.append(QueryPart(term, word, wordnr))
- query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
- query.nodes[-1].btype = qmod.BreakType.WORD
- wordnr += 1
- query.nodes[-1].btype = qmod.BreakType.PHRASE
+ parts.append(QueryPart(term, word,
+ PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
+ query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
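+ # After the last part of the word, record the original break
+ # character and let the final part carry its penalty.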
+ query.nodes[-1].btype = breakchar
+ parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
- for word, wrange in yield_words(parts, phrase_start):
- words[word].append(wrange)
+ extract_words(parts, phrase_start, words)
phrase_start = len(parts)
- query.nodes[-1].btype = qmod.BreakType.END
+ query.nodes[-1].btype = qmod.BREAK_END
return parts, words
""" Add tokens to query that are not saved in the database.
"""
for part, node, i in zip(parts, query.nodes, range(1000)):
- if len(part.token) <= 4 and part[0].isdigit()\
- and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
- query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
+ if len(part.token) <= 4 and part.token.isdigit()\
+ and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
+ query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
ICUToken(penalty=0.5, token=0,
count=1, addr_count=1, lookup_word=part.token,
word_token=part.token, info=None))
""" Add penalties to tokens that depend on presence of other token.
"""
for i, node, tlist in query.iter_token_lists():
- if tlist.ttype == qmod.TokenType.POSTCODE:
+ if tlist.ttype == qmod.TOKEN_POSTCODE:
for repl in node.starting:
- if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
- and (repl.ttype != qmod.TokenType.HOUSENUMBER
+ if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
+ and (repl.ttype != qmod.TOKEN_HOUSENUMBER
or len(tlist.tokens[0].lookup_word) > 4):
repl.add_penalty(0.39)
- elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
+ elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
and len(tlist.tokens[0].lookup_word) <= 3):
if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
for repl in node.starting:
- if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
+ if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
repl.add_penalty(0.5 - tlist.tokens[0].penalty)
- elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
+ elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
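+ # Rebuild the normalized form of the whole range. Parts split from
+ # the same original word share one normalized string, so it is only
+ # added once, namely when the preceding break is not BREAK_TOKEN.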
norm = parts[i].normalized
for j in range(i + 1, tlist.end):
- if parts[j - 1].word_number != parts[j].word_number:
+ if query.nodes[j].btype != qmod.BREAK_TOKEN:
norm += ' ' + parts[j].normalized
for token in tlist.tokens:
cast(ICUToken, token).rematch(norm)
def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
- out = query.nodes[0].btype.value
+ out = query.nodes[0].btype
for node, part in zip(query.nodes[1:], parts):
- out += part.token + node.btype.value
+ out += part.token + node.btype
return out
for tlist in node.starting:
for token in tlist.tokens:
t = cast(ICUToken, token)
- yield [tlist.ttype.name, t.token, t.word_token or '',
+ yield [tlist.ttype, t.token, t.word_token or '',
t.lookup_word or '', t.penalty, t.count, t.info]