]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/search/db_search_builder.py
replace BreakType enum with simple char constants
[nominatim.git] / src / nominatim_api / search / db_search_builder.py
index 1ac6db2b2a96d5ebc5095d81bd936b0002194600..7e76de1408234453d2ad4137411b0feff57fb879 100644 (file)
@@ -11,7 +11,7 @@ from typing import Optional, List, Tuple, Iterator, Dict
 import heapq
 
 from ..types import SearchDetails, DataLayer
-from .query import QueryStruct, Token, TokenType, TokenRange, BreakType
+from . import query as qmod
 from .token_assignment import TokenAssignment
 from . import db_search_fields as dbf
 from . import db_searches as dbs
@@ -42,7 +42,7 @@ def build_poi_search(category: List[Tuple[str, str]],
     class _PoiData(dbf.SearchData):
         penalty = 0.0
         qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
-        countries=ccs
+        countries = ccs
 
     return dbs.PoiSearch(_PoiData())
 
@@ -51,19 +51,17 @@ class SearchBuilder:
     """ Build the abstract search queries from token assignments.
     """
 
-    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
+    def __init__(self, query: qmod.QueryStruct, details: SearchDetails) -> None:
         self.query = query
         self.details = details
 
-
     @property
     def configured_for_country(self) -> bool:
         """ Return true if the search details are configured to
             allow countries in the result.
         """
         return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
-               and self.details.layer_enabled(DataLayer.ADDRESS)
-
+            and self.details.layer_enabled(DataLayer.ADDRESS)
 
     @property
     def configured_for_postcode(self) -> bool:
@@ -71,8 +69,7 @@ class SearchBuilder:
             allow postcodes in the result.
         """
         return self.details.min_rank <= 5 and self.details.max_rank >= 11\
-               and self.details.layer_enabled(DataLayer.ADDRESS)
-
+            and self.details.layer_enabled(DataLayer.ADDRESS)
 
     @property
     def configured_for_housenumbers(self) -> bool:
@@ -80,8 +77,7 @@ class SearchBuilder:
             allow addresses in the result.
         """
         return self.details.max_rank >= 30 \
-               and self.details.layer_enabled(DataLayer.ADDRESS)
-
+            and self.details.layer_enabled(DataLayer.ADDRESS)
 
     def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
         """ Yield all possible abstract searches for the given token assignment.
@@ -92,7 +88,7 @@ class SearchBuilder:
 
         near_items = self.get_near_items(assignment)
         if near_items is not None and not near_items:
-            return # impossible compbination of near items and category parameter
+            return  # impossible combination of near items and category parameter
 
         if assignment.name is None:
             if near_items and not sdata.postcodes:
@@ -101,7 +97,7 @@ class SearchBuilder:
                 builder = self.build_poi_search(sdata)
             elif assignment.housenumber:
                 hnr_tokens = self.query.get_tokens(assignment.housenumber,
-                                                   TokenType.HOUSENUMBER)
+                                                   qmod.TokenType.HOUSENUMBER)
                 builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
             else:
                 builder = self.build_special_search(sdata, assignment.address,
@@ -123,7 +119,6 @@ class SearchBuilder:
                 search.penalty += assignment.penalty
                 yield search
 
-
     def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
         """ Build abstract search query for a simple category search.
             This kind of search requires an additional geographic constraint.
@@ -132,9 +127,8 @@ class SearchBuilder:
            and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
             yield dbs.PoiSearch(sdata)
 
-
     def build_special_search(self, sdata: dbf.SearchData,
-                             address: List[TokenRange],
+                             address: List[qmod.TokenRange],
                              is_category: bool) -> Iterator[dbs.AbstractSearch]:
         """ Build abstract search queries for searches that do not involve
             a named place.
@@ -154,12 +148,10 @@ class SearchBuilder:
                                                  [t.token for r in address
                                                   for t in self.query.get_partials_list(r)],
                                                  lookups.Restrict)]
-                penalty += 0.2
             yield dbs.PostcodeSearch(penalty, sdata)
 
-
-    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
-                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
+    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[qmod.Token],
+                                 address: List[qmod.TokenRange]) -> Iterator[dbs.AbstractSearch]:
         """ Build a simple address search for special entries where the
             housenumber is the main name token.
         """
@@ -167,7 +159,7 @@ class SearchBuilder:
         expected_count = sum(t.count for t in hnrs)
 
         partials = {t.token: t.addr_count for trange in address
-                       for t in self.query.get_partials_list(trange)}
+                    for t in self.query.get_partials_list(trange)}
 
         if not partials:
             # can happen when none of the partials is indexed
@@ -181,7 +173,7 @@ class SearchBuilder:
                                                  list(partials), lookups.LookupAll))
         else:
             addr_fulls = [t.token for t
-                          in self.query.get_tokens(address[0], TokenType.WORD)]
+                          in self.query.get_tokens(address[0], qmod.TokenType.WORD)]
             if len(addr_fulls) > 5:
                 return
             sdata.lookups.append(
@@ -190,9 +182,8 @@ class SearchBuilder:
         sdata.housenumbers = dbf.WeightedStrings([], [])
         yield dbs.PlaceSearch(0.05, sdata, expected_count)
 
-
     def build_name_search(self, sdata: dbf.SearchData,
-                          name: TokenRange, address: List[TokenRange],
+                          name: qmod.TokenRange, address: List[qmod.TokenRange],
                           is_category: bool) -> Iterator[dbs.AbstractSearch]:
         """ Build abstract search queries for simple name or address searches.
         """
@@ -205,14 +196,13 @@ class SearchBuilder:
                 sdata.lookups = lookup
                 yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
 
-
-    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
-                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
+    def yield_lookups(self, name: qmod.TokenRange, address: List[qmod.TokenRange]
+                      ) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
         """ Yield all variants how the given name and address should best
             be searched for. This takes into account how frequent the terms
             are and tries to find a lookup that optimizes index use.
         """
-        penalty = 0.0 # extra penalty
+        penalty = 0.0  # extra penalty
         name_partials = {t.token: t for t in self.query.get_partials_list(name)}
 
         addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
@@ -226,12 +216,12 @@ class SearchBuilder:
 
         addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 30000
         # Partial term to frequent. Try looking up by rare full names first.
-        name_fulls = self.query.get_tokens(name, TokenType.WORD)
+        name_fulls = self.query.get_tokens(name, qmod.TokenType.WORD)
         if name_fulls:
             fulls_count = sum(t.count for t in name_fulls)
 
             if fulls_count < 50000 or addr_count < 30000:
-                yield penalty,fulls_count / (2**len(addr_tokens)), \
+                yield penalty, fulls_count / (2**len(addr_tokens)), \
                     self.get_full_name_ranking(name_fulls, addr_partials,
                                                fulls_count > 30000 / max(1, len(addr_tokens)))
 
@@ -241,12 +231,11 @@ class SearchBuilder:
         if exp_count < 10000 and addr_count < 20000:
             penalty += 0.35 * max(1 if name_fulls else 0.1,
                                   5 - len(name_partials) - len(addr_tokens))
-            yield penalty, exp_count,\
-                  self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
-
+            yield penalty, exp_count, \
+                self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
 
     def get_name_address_ranking(self, name_tokens: List[int],
-                                 addr_partials: List[Token]) -> List[dbf.FieldLookup]:
+                                 addr_partials: List[qmod.Token]) -> List[dbf.FieldLookup]:
         """ Create a ranking expression looking up by name and address.
         """
         lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
@@ -268,8 +257,7 @@ class SearchBuilder:
 
         return lookup
 
-
-    def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
+    def get_full_name_ranking(self, name_fulls: List[qmod.Token], addr_partials: List[qmod.Token],
                               use_lookup: bool) -> List[dbf.FieldLookup]:
         """ Create a ranking expression with full name terms and
             additional address lookup. When 'use_lookup' is true, then
@@ -293,12 +281,11 @@ class SearchBuilder:
         return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                       addr_restrict_tokens, addr_lookup_tokens)
 
-
-    def get_name_ranking(self, trange: TokenRange,
+    def get_name_ranking(self, trange: qmod.TokenRange,
                          db_field: str = 'name_vector') -> dbf.FieldRanking:
         """ Create a ranking expression for a name term in the given range.
         """
-        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
+        name_fulls = self.query.get_tokens(trange, qmod.TokenType.WORD)
         ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
         ranks.sort(key=lambda r: r.penalty)
         # Fallback, sum of penalty for partials
@@ -306,8 +293,7 @@ class SearchBuilder:
         default = sum(t.penalty for t in name_partials) + 0.2
         return dbf.FieldRanking(db_field, default, ranks)
 
-
-    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
+    def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:
         """ Create a list of ranking expressions for an address term
             for the given ranges.
         """
@@ -315,13 +301,13 @@ class SearchBuilder:
         heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
         ranks: List[dbf.RankedTokens] = []
 
-        while todo: # pylint: disable=too-many-nested-blocks
+        while todo:
             neglen, pos, rank = heapq.heappop(todo)
             for tlist in self.query.nodes[pos].starting:
-                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
+                if tlist.ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.WORD):
                     if tlist.end < trange.end:
                         chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
-                        if tlist.ttype == TokenType.PARTIAL:
+                        if tlist.ttype == qmod.TokenType.PARTIAL:
                             penalty = rank.penalty + chgpenalty \
                                       + max(t.penalty for t in tlist.tokens)
                             heapq.heappush(todo, (neglen - 1, tlist.end,
@@ -331,7 +317,7 @@ class SearchBuilder:
                                 heapq.heappush(todo, (neglen - 1, tlist.end,
                                                       rank.with_token(t, chgpenalty)))
                     elif tlist.end == trange.end:
-                        if tlist.ttype == TokenType.PARTIAL:
+                        if tlist.ttype == qmod.TokenType.PARTIAL:
                             ranks.append(dbf.RankedTokens(rank.penalty
                                                           + max(t.penalty for t in tlist.tokens),
                                                           rank.tokens))
@@ -354,7 +340,6 @@ class SearchBuilder:
 
         return dbf.FieldRanking('nameaddress_vector', default, ranks)
 
-
     def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
         """ Collect the tokens for the non-name search fields in the
             assignment.
@@ -372,11 +357,11 @@ class SearchBuilder:
         if assignment.housenumber:
             sdata.set_strings('housenumbers',
                               self.query.get_tokens(assignment.housenumber,
-                                                    TokenType.HOUSENUMBER))
+                                                    qmod.TokenType.HOUSENUMBER))
         if assignment.postcode:
             sdata.set_strings('postcodes',
                               self.query.get_tokens(assignment.postcode,
-                                                    TokenType.POSTCODE))
+                                                    qmod.TokenType.POSTCODE))
         if assignment.qualifier:
             tokens = self.get_qualifier_tokens(assignment.qualifier)
             if not tokens:
@@ -401,31 +386,28 @@ class SearchBuilder:
 
         return sdata
 
-
-    def get_country_tokens(self, trange: TokenRange) -> List[Token]:
+    def get_country_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
         """ Return the list of country tokens for the given range,
             optionally filtered by the country list from the details
             parameters.
         """
-        tokens = self.query.get_tokens(trange, TokenType.COUNTRY)
+        tokens = self.query.get_tokens(trange, qmod.TokenType.COUNTRY)
         if self.details.countries:
             tokens = [t for t in tokens if t.lookup_word in self.details.countries]
 
         return tokens
 
-
-    def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
+    def get_qualifier_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
         """ Return the list of qualifier tokens for the given range,
             optionally filtered by the qualifier list from the details
             parameters.
         """
-        tokens = self.query.get_tokens(trange, TokenType.QUALIFIER)
+        tokens = self.query.get_tokens(trange, qmod.TokenType.QUALIFIER)
         if self.details.categories:
             tokens = [t for t in tokens if t.get_category() in self.details.categories]
 
         return tokens
 
-
     def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
         """ Collect tokens for near items search or use the categories
             requested per parameter.
@@ -433,7 +415,7 @@ class SearchBuilder:
         """
         if assignment.near_item:
             tokens: Dict[Tuple[str, str], float] = {}
-            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
+            for t in self.query.get_tokens(assignment.near_item, qmod.TokenType.NEAR_ITEM):
                 cat = t.get_category()
                 # The category of a near search will be that of near_item.
                 # Thus, if search is restricted to a category parameter,
@@ -447,10 +429,11 @@ class SearchBuilder:
 
 
 PENALTY_WORDCHANGE = {
-    BreakType.START: 0.0,
-    BreakType.END: 0.0,
-    BreakType.PHRASE: 0.0,
-    BreakType.WORD: 0.1,
-    BreakType.PART: 0.2,
-    BreakType.TOKEN: 0.4
+    qmod.BREAK_START: 0.0,
+    qmod.BREAK_END: 0.0,
+    qmod.BREAK_PHRASE: 0.0,
+    qmod.BREAK_SOFT_PHRASE: 0.0,
+    qmod.BREAK_WORD: 0.1,
+    qmod.BREAK_PART: 0.2,
+    qmod.BREAK_TOKEN: 0.4
 }