git.openstreetmap.org Git - nominatim.git/commitdiff
Merge remote-tracking branch 'upstream/master'
author Sarah Hoffmann <lonvia@denofr.de>
Mon, 27 Nov 2023 11:02:33 +0000 (12:02 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Mon, 27 Nov 2023 11:02:33 +0000 (12:02 +0100)
nominatim/api/search/db_search_builder.py
nominatim/api/search/db_search_fields.py
nominatim/api/search/db_searches.py
nominatim/api/search/geocoder.py
nominatim/api/search/query.py
test/python/api/search/test_api_search_query.py
test/python/api/search/test_db_search_builder.py
test/python/api/search/test_token_assignment.py

index 905b5c621cfe8cd99e28ec0bf13a2a5c9bff392a..7826925aed6ce77271e92bbef4612a3b1e5357bd 100644 (file)
@@ -7,7 +7,7 @@
 """
 Conversion from token assignment to an abstract DB search.
 """
-from typing import Optional, List, Tuple, Iterator
+from typing import Optional, List, Tuple, Iterator, Dict
 import heapq
 
 from nominatim.api.types import SearchDetails, DataLayer
@@ -339,12 +339,13 @@ class SearchBuilder:
             Returns None if no category search is requested.
         """
         if assignment.category:
-            tokens = [t for t in self.query.get_tokens(assignment.category,
-                                                       TokenType.CATEGORY)
-                      if not self.details.categories
-                         or t.get_category() in self.details.categories]
-            return dbf.WeightedCategories([t.get_category() for t in tokens],
-                                          [t.penalty for t in tokens])
+            tokens: Dict[Tuple[str, str], float] = {}
+            for t in self.query.get_tokens(assignment.category, TokenType.CATEGORY):
+                cat = t.get_category()
+                if (not self.details.categories or cat in self.details.categories)\
+                   and t.penalty < tokens.get(cat, 1000.0):
+                    tokens[cat] = t.penalty
+            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))
 
         if self.details.categories:
             return dbf.WeightedCategories(self.details.categories,
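The replaced list comprehension could emit the same category several times with different penalties; the dict keeps only the cheapest entry per category. A minimal standalone sketch of that deduplication pattern (function name and the (category, penalty) pairs are illustrative, not the real Token class):

    from typing import Dict, List, Tuple

    def cheapest_per_category(tokens: List[Tuple[Tuple[str, str], float]]
                              ) -> Dict[Tuple[str, str], float]:
        # Keep only the smallest penalty seen for each (class, type) pair.
        best: Dict[Tuple[str, str], float] = {}
        for cat, penalty in tokens:
            if penalty < best.get(cat, 1000.0):
                best[cat] = penalty
        return best

    # Duplicate categories collapse to their cheapest penalty:
    cats = cheapest_per_category([(('amenity', 'bar'), 0.3),
                                  (('amenity', 'bar'), 0.1),
                                  (('amenity', 'pub'), 0.2)])
    assert cats == {('amenity', 'bar'): 0.1, ('amenity', 'pub'): 0.2}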
index 612e90597df2064a4ba6bf19221076093c5c55f7..59af826086db86027f2c808dee51824fb17e72ff 100644 (file)
@@ -7,7 +7,7 @@
 """
 Data structures for more complex fields in abstract search descriptions.
 """
-from typing import List, Tuple, Iterator, cast
+from typing import List, Tuple, Iterator, cast, Dict
 import dataclasses
 
 import sqlalchemy as sa
@@ -195,10 +195,17 @@ class SearchData:
         """ Set the qulaifier field from the given tokens.
         """
         if tokens:
-            min_penalty = min(t.penalty for t in tokens)
+            categories: Dict[Tuple[str, str], float] = {}
+            min_penalty = 1000.0
+            for t in tokens:
+                if t.penalty < min_penalty:
+                    min_penalty = t.penalty
+                cat = t.get_category()
+                if t.penalty < categories.get(cat, 1000.0):
+                    categories[cat] = t.penalty
             self.penalty += min_penalty
-            self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
-                                                 [t.penalty - min_penalty for t in tokens])
+            self.qualifiers = WeightedCategories(list(categories.keys()),
+                                                 list(categories.values()))
 
 
     def set_ranking(self, rankings: List[FieldRanking]) -> None:
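set_qualifiers now computes the overall minimum and the per-category minimum in a single pass: the overall minimum feeds the search penalty, while the per-category values become the qualifier weights (note they now carry the absolute token penalty rather than the offset from the minimum). A sketch of the loop, assuming tokens reduced to (category, penalty) pairs:

    from typing import Dict, List, Tuple

    def split_penalties(tokens: List[Tuple[Tuple[str, str], float]]
                        ) -> Tuple[float, Dict[Tuple[str, str], float]]:
        categories: Dict[Tuple[str, str], float] = {}
        min_penalty = 1000.0
        for cat, penalty in tokens:
            if penalty < min_penalty:
                min_penalty = penalty          # goes into the search penalty
            if penalty < categories.get(cat, 1000.0):
                categories[cat] = penalty      # per-category qualifier weight
        return min_penalty, categories

    base, cats = split_penalties([(('tourism', 'hotel'), 0.4),
                                  (('tourism', 'hotel'), 0.2),
                                  (('tourism', 'motel'), 0.5)])
    assert base == 0.2
    assert cats == {('tourism', 'hotel'): 0.2, ('tourism', 'motel'): 0.5}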
index 63da4c5d57a1d2e1ad2c4e2af2d0b6622b990e75..ce5fbc6341cd7d68efebe2100dc858a9ef6a2ffc 100644 (file)
@@ -66,7 +66,7 @@ def _select_placex(t: SaFromClause) -> SaSelect:
                      t.c.class_, t.c.type,
                      t.c.address, t.c.extratags,
                      t.c.housenumber, t.c.postcode, t.c.country_code,
-                     t.c.importance, t.c.wikipedia,
+                     t.c.wikipedia,
                      t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
                      t.c.linked_place_id, t.c.admin_level,
                      t.c.centroid,
@@ -158,7 +158,8 @@ async def _get_placex_housenumbers(conn: SearchConnection,
                                    place_ids: List[int],
                                    details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
     t = conn.t.placex
-    sql = _select_placex(t).where(t.c.place_id.in_(place_ids))
+    sql = _select_placex(t).add_columns(t.c.importance)\
+                           .where(t.c.place_id.in_(place_ids))
 
     if details.geometry_output:
         sql = _add_geometry_columns(sql, t.c.geometry, details)
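_select_placex no longer selects t.c.importance, so each call site decides how the column is filled: plain lookups re-add the stored value via add_columns, while distance-ordered searches substitute a negated distance instead. A sketch of the two variants against a hypothetical trimmed-down placex table:

    import sqlalchemy as sa

    meta = sa.MetaData()
    # Hypothetical stand-in for conn.t.placex.
    placex = sa.Table('placex', meta,
                      sa.Column('place_id', sa.BigInteger),
                      sa.Column('importance', sa.Float),
                      sa.Column('centroid', sa.String))  # PostGIS geometry in reality

    base = sa.select(placex.c.place_id)        # stand-in for _select_placex

    # Variant 1: use the importance stored in the database.
    stored = base.add_columns(placex.c.importance)

    # Variant 2: rank by proximity - closer means higher importance.
    by_dist = base.add_columns(
        (-sa.func.ST_Distance(placex.c.centroid, sa.literal('POINT(0 0)')))
          .label('importance'))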
@@ -255,9 +256,20 @@ class NearSearch(AbstractSearch):
 
         base.sort(key=lambda r: (r.accuracy, r.rank_search))
         max_accuracy = base[0].accuracy + 0.5
+        if base[0].rank_address == 0:
+            min_rank = 0
+            max_rank = 0
+        elif base[0].rank_address < 26:
+            min_rank = 1
+            max_rank = min(25, base[0].rank_address + 4)
+        else:
+            min_rank = 26
+            max_rank = 30
         base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
                                                      and r.accuracy <= max_accuracy
-                                                     and r.bbox and r.bbox.area < 20)
+                                                     and r.bbox and r.bbox.area < 20
+                                                     and r.rank_address >= min_rank
+                                                     and r.rank_address <= max_rank)
 
         if base:
             baseids = [b.place_id for b in base[:5] if b.place_id]
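The new filter restricts follow-up category lookups to a rank window derived from the best base result: country-level results (rank_address 0) only match other countries, anything above street level matches down to four address ranks below itself but never past 25, and street level and below (26..30) stays within that band. Read as a pure function (the name is illustrative):

    from typing import Tuple

    def rank_window(base_rank_address: int) -> Tuple[int, int]:
        if base_rank_address == 0:
            return 0, 0                               # countries pair with countries
        if base_rank_address < 26:
            return 1, min(25, base_rank_address + 4)  # capped above street level
        return 26, 30                                 # street level and below

    assert rank_window(0) == (0, 0)
    assert rank_window(16) == (1, 20)
    assert rank_window(24) == (1, 25)
    assert rank_window(28) == (26, 30)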
@@ -279,28 +291,37 @@ class NearSearch(AbstractSearch):
         """
         table = await conn.get_class_table(*category)
 
-        t = conn.t.placex
         tgeom = conn.t.placex.alias('pgeom')
 
-        sql = _select_placex(t).where(tgeom.c.place_id.in_(ids))\
-                               .where(t.c.class_ == category[0])\
-                               .where(t.c.type == category[1])
-
         if table is None:
             # No classtype table available, do a simplified lookup in placex.
-            sql = sql.join(tgeom, t.c.geometry.ST_DWithin(tgeom.c.centroid, 0.01))\
-                     .order_by(tgeom.c.centroid.ST_Distance(t.c.centroid))
+            table = conn.t.placex.alias('inner')
+            sql = sa.select(table.c.place_id,
+                            sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid))
+                              .label('dist'))\
+                    .join(tgeom, table.c.geometry.intersects(tgeom.c.centroid.ST_Expand(0.01)))\
+                    .where(table.c.class_ == category[0])\
+                    .where(table.c.type == category[1])
         else:
             # Use classtype table. We can afford to use a larger
             # radius for the lookup.
-            sql = sql.join(table, t.c.place_id == table.c.place_id)\
-                     .join(tgeom,
-                           table.c.centroid.ST_CoveredBy(
-                               sa.case((sa.and_(tgeom.c.rank_address > 9,
+            sql = sa.select(table.c.place_id,
+                            sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid))
+                              .label('dist'))\
+                    .join(tgeom,
+                          table.c.centroid.ST_CoveredBy(
+                              sa.case((sa.and_(tgeom.c.rank_address > 9,
                                                 tgeom.c.geometry.is_area()),
-                                        tgeom.c.geometry),
-                                       else_ = tgeom.c.centroid.ST_Expand(0.05))))\
-                     .order_by(tgeom.c.centroid.ST_Distance(table.c.centroid))
+                                       tgeom.c.geometry),
+                                      else_ = tgeom.c.centroid.ST_Expand(0.05))))
+
+        inner = sql.where(tgeom.c.place_id.in_(ids))\
+                   .group_by(table.c.place_id).subquery()
+
+        t = conn.t.placex
+        sql = _select_placex(t).add_columns((-inner.c.dist).label('importance'))\
+                               .join(inner, inner.c.place_id == t.c.place_id)\
+                               .order_by(inner.c.dist)
 
         sql = sql.where(no_index(t.c.rank_address).between(MIN_RANK_PARAM, MAX_RANK_PARAM))
         if details.countries:
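The rewritten lookup_category splits the query in two: a grouped inner subquery computes one minimal distance per candidate place, and only then is placex joined for the full result rows, with the negated distance exported as the importance. A self-contained sketch of that shape; the table and the join predicate are simplified stand-ins (the real code restricts the join with ST_Expand/ST_CoveredBy on PostGIS geometries):

    import sqlalchemy as sa

    meta = sa.MetaData()
    placex = sa.Table('placex', meta,
                      sa.Column('place_id', sa.BigInteger),
                      sa.Column('class_', sa.String),
                      sa.Column('type', sa.String),
                      sa.Column('centroid', sa.String))  # geometry in reality

    tgeom = placex.alias('pgeom')
    inner_t = placex.alias('inner')

    # Step 1: one row per candidate with its minimal distance to any base
    # geometry (spatial predicate reduced to a plain cross join here).
    inner = sa.select(inner_t.c.place_id,
                      sa.func.min(sa.func.ST_Distance(tgeom.c.centroid,
                                                      inner_t.c.centroid))
                        .label('dist'))\
              .join(tgeom, sa.true())\
              .where(inner_t.c.class_ == 'amenity')\
              .group_by(inner_t.c.place_id).subquery()

    # Step 2: join back to placex for the full rows; the negated distance
    # doubles as the importance so closer results rank higher.
    sql = sa.select(placex, (-inner.c.dist).label('importance'))\
            .join(inner, inner.c.place_id == placex.c.place_id)\
            .order_by(inner.c.dist)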
@@ -342,6 +363,8 @@ class PoiSearch(AbstractSearch):
             # simply search in placex table
             def _base_query() -> SaSelect:
                 return _select_placex(t) \
+                           .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
+                                         .label('importance'))\
                            .where(t.c.linked_place_id == None) \
                            .where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
                            .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
@@ -370,6 +393,7 @@ class PoiSearch(AbstractSearch):
                 table = await conn.get_class_table(*category)
                 if table is not None:
                     sql = _select_placex(t)\
+                               .add_columns(t.c.importance)\
                                .join(table, t.c.place_id == table.c.place_id)\
                                .where(t.c.class_ == category[0])\
                                .where(t.c.type == category[1])
@@ -415,6 +439,7 @@ class CountrySearch(AbstractSearch):
 
         ccodes = self.countries.values
         sql = _select_placex(t)\
+                .add_columns(t.c.importance)\
                 .where(t.c.country_code.in_(ccodes))\
                 .where(t.c.rank_address == 4)
 
@@ -591,15 +616,7 @@ class PlaceSearch(AbstractSearch):
         tsearch = conn.t.search_name
 
         sql: SaLambdaSelect = sa.lambda_stmt(lambda:
-                  sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
-                            t.c.class_, t.c.type,
-                            t.c.address, t.c.extratags, t.c.admin_level,
-                            t.c.housenumber, t.c.postcode, t.c.country_code,
-                            t.c.wikipedia,
-                            t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
-                            t.c.centroid,
-                            t.c.geometry.ST_Expand(0).label('bbox'))
-                   .where(t.c.place_id == tsearch.c.place_id))
+                  _select_placex(t).where(t.c.place_id == tsearch.c.place_id))
 
 
         if details.geometry_output:
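PlaceSearch previously inlined a copy of the column list inside the cached lambda; pointing the lambda at _select_placex removes the duplication while sa.lambda_stmt still caches the statement construction. A minimal sketch with hypothetical trimmed tables:

    import sqlalchemy as sa

    meta = sa.MetaData()
    placex = sa.Table('placex', meta,
                      sa.Column('place_id', sa.BigInteger),
                      sa.Column('name', sa.String))
    search_name = sa.Table('search_name', meta,
                           sa.Column('place_id', sa.BigInteger))

    def _select_placex(t):
        # Abbreviated stand-in for the real helper's full column list.
        return sa.select(t.c.place_id, t.c.name)

    sql = sa.lambda_stmt(lambda: _select_placex(placex)
                                   .where(placex.c.place_id == search_name.c.place_id))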
index 7ff3ed08af9dcfa23c8072d4d57152a253930007..91c45b65a76e977f93cc770f6303557211af470d 100644 (file)
@@ -134,7 +134,10 @@ class ForwardGeocoder:
             return
 
         for result in results:
-            if not result.display_name:
+            # Negative importance indicates ordering by distance, which is
+            # more important than word matching.
+            if not result.display_name\
+               or (result.importance is not None and result.importance < 0):
                 continue
             distance = 0.0
             norm = self.query_analyzer.normalize_text(result.display_name)
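The guard depends on the convention established in db_searches.py above: distance-ordered results carry importance = -dist, which is always negative, and their ordering must not be disturbed by display-name reranking. A sketch of the skip condition with a hypothetical minimal result type (the real one is nres.SearchResult):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Result:
        display_name: Optional[str]
        importance: Optional[float]

    def needs_rerank(result: Result) -> bool:
        # Skip unnamed results and results ordered by distance
        # (signalled by a negative importance).
        if not result.display_name:
            return False
        return result.importance is None or result.importance >= 0

    assert needs_rerank(Result('Main St', 0.3))
    assert not needs_rerank(Result('Main St', -0.001))   # distance-ordered
    assert not needs_rerank(Result(None, 0.3))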
index 5d75eb0fbe98c492638bbb174b5db930490f6788..4bf009a53a7add87b44cee2ac2508b72e1846f2b 100644 (file)
@@ -70,14 +70,16 @@ class PhraseType(enum.Enum):
     COUNTRY = enum.auto()
     """ Contains the country name or code. """
 
-    def compatible_with(self, ttype: TokenType) -> bool:
+    def compatible_with(self, ttype: TokenType,
+                        is_full_phrase: bool) -> bool:
         """ Check if the given token type can be used with the phrase type.
         """
         if self == PhraseType.NONE:
-            return True
+            return not is_full_phrase or ttype != TokenType.QUALIFIER
         if self == PhraseType.AMENITY:
-            return ttype in (TokenType.WORD, TokenType.PARTIAL,
-                             TokenType.QUALIFIER, TokenType.CATEGORY)
+            return ttype in (TokenType.WORD, TokenType.PARTIAL)\
+                   or (is_full_phrase and ttype == TokenType.CATEGORY)\
+                   or (not is_full_phrase and ttype == TokenType.QUALIFIER)
         if self == PhraseType.STREET:
             return ttype in (TokenType.WORD, TokenType.PARTIAL, TokenType.HOUSENUMBER)
         if self == PhraseType.POSTCODE:
@@ -244,7 +246,9 @@ class QueryStruct:
             be added to, then the token is silently dropped.
         """
         snode = self.nodes[trange.start]
-        if snode.ptype.compatible_with(ttype):
+        full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\
+                      and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
+        if snode.ptype.compatible_with(ttype, full_phrase):
             tlist = snode.get_tokens(trange.end, ttype)
             if tlist is None:
                 snode.starting.append(TokenList(trange.end, ttype, [token]))
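With the new is_full_phrase flag the compatibility rules become position dependent: in an AMENITY phrase a CATEGORY token must span the whole phrase while a QUALIFIER must not, and in a NONE phrase a QUALIFIER may no longer cover the full phrase either. A sketch of the AMENITY branch with string stand-ins for the enums:

    def amenity_compatible(ttype: str, is_full_phrase: bool) -> bool:
        if ttype in ('WORD', 'PARTIAL'):
            return True                   # always allowed
        if ttype == 'CATEGORY':
            return is_full_phrase         # "bar" alone names a category...
        if ttype == 'QUALIFIER':
            return not is_full_phrase     # ...but qualifies another word
        return False

    assert amenity_compatible('CATEGORY', True)
    assert not amenity_compatible('CATEGORY', False)
    assert amenity_compatible('QUALIFIER', False)
    assert not amenity_compatible('QUALIFIER', True)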
index f8c9c2dc865ba9f8ca527014c1d292dfbba14313..69a17412cf14170cd0a2c6a9209dab73676dc5b1 100644 (file)
@@ -28,12 +28,12 @@ def mktoken(tid: int):
                                          ('COUNTRY', 'COUNTRY'),
                                          ('POSTCODE', 'POSTCODE')])
 def test_phrase_compatible(ptype, ttype):
-    assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype])
+    assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype], False)
 
 
 @pytest.mark.parametrize('ptype', ['COUNTRY', 'POSTCODE'])
 def test_phrase_incompatible(ptype):
-    assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL)
+    assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL, True)
 
 
 def test_query_node_empty():
@@ -99,3 +99,36 @@ def test_query_struct_incompatible_token():
 
     assert q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL) == []
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.COUNTRY)) == 1
+
+
+def test_query_struct_amenity_single_word():
+    q = query.QueryStruct([query.Phrase(query.PhraseType.AMENITY, 'bar')])
+    q.add_node(query.BreakType.END, query.PhraseType.NONE)
+
+    q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1))
+    q.add_token(query.TokenRange(0, 1), query.TokenType.CATEGORY, mktoken(2))
+    q.add_token(query.TokenRange(0, 1), query.TokenType.QUALIFIER, mktoken(3))
+
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.CATEGORY)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.QUALIFIER)) == 0
+
+
+def test_query_struct_amenity_two_words():
+    q = query.QueryStruct([query.Phrase(query.PhraseType.AMENITY, 'foo bar')])
+    q.add_node(query.BreakType.WORD, query.PhraseType.AMENITY)
+    q.add_node(query.BreakType.END, query.PhraseType.NONE)
+
+    for trange in [(0, 1), (1, 2)]:
+        q.add_token(query.TokenRange(*trange), query.TokenType.PARTIAL, mktoken(1))
+        q.add_token(query.TokenRange(*trange), query.TokenType.CATEGORY, mktoken(2))
+        q.add_token(query.TokenRange(*trange), query.TokenType.QUALIFIER, mktoken(3))
+
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.CATEGORY)) == 0
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.QUALIFIER)) == 1
+
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.CATEGORY)) == 0
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.QUALIFIER)) == 1
+
index c93b8ead3c2fda0a49320726d72bda6c4282bbb1..c10a6c77f2917828b1ca4007f36789db5881b1e6 100644 (file)
@@ -21,21 +21,18 @@ class MyToken(Token):
 
 
 def make_query(*args):
-    q = None
+    q = QueryStruct([Phrase(PhraseType.NONE, '')])
 
-    for tlist in args:
-        if q is None:
-            q = QueryStruct([Phrase(PhraseType.NONE, '')])
-        else:
-            q.add_node(BreakType.WORD, PhraseType.NONE)
+    for _ in range(max(inner[0] for tlist in args for inner in tlist)):
+        q.add_node(BreakType.WORD, PhraseType.NONE)
+    q.add_node(BreakType.END, PhraseType.NONE)
 
-        start = len(q.nodes) - 1
+    for start, tlist in enumerate(args):
         for end, ttype, tinfo in tlist:
             for tid, word in tinfo:
                 q.add_token(TokenRange(start, end), ttype,
                             MyToken(0.5 if ttype == TokenType.PARTIAL else 0.0, tid, 1, word, True))
 
-    q.add_node(BreakType.END, PhraseType.NONE)
 
     return q
 
index dc123403ab24185aa78e59d842cecb0bce48e296..6dc25b1e7507159dc2163d03ac13451d817559fe 100644 (file)
@@ -18,21 +18,17 @@ class MyToken(Token):
 
 
 def make_query(*args):
-    q = None
+    q = QueryStruct([Phrase(args[0][1], '')])
     dummy = MyToken(3.0, 45, 1, 'foo', True)
 
-    for btype, ptype, tlist in args:
-        if q is None:
-            q = QueryStruct([Phrase(ptype, '')])
-        else:
-            q.add_node(btype, ptype)
+    for btype, ptype, _ in args[1:]:
+        q.add_node(btype, ptype)
+    q.add_node(BreakType.END, PhraseType.NONE)
 
-        start = len(q.nodes) - 1
-        for end, ttype in tlist:
+    for start, t in enumerate(args):
+        for end, ttype in t[2]:
             q.add_token(TokenRange(start, end), ttype, dummy)
 
-    q.add_node(BreakType.END, PhraseType.NONE)
-
     return q
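The same pattern applies here: nodes are created first, then tokens are attached by start position. A hypothetical call (note the helper only uses the phrase type of the first entry; its break type is ignored):

    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.PARTIAL)]))
    # Yields nodes START-WORD-END with one PARTIAL token on each of the
    # ranges (0, 1) and (1, 2).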