]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/db_search_lookups.py
prefilter bad results before adding details and reranking
[nominatim.git] / nominatim / api / search / db_search_lookups.py
index 3e307235b850b954a24a5221273fa5c930785053..aa5cef5f47e491d68fa6b69961f303fcb3b8dcb0 100644 (file)
@@ -26,18 +26,38 @@ class LookupAll(LookupType):
     inherit_cache = True
 
     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
-        super().__init__(getattr(table.c, column),
+        super().__init__(table.c.place_id, getattr(table.c, column), column,
                          sa.type_coerce(tokens, IntArray))
 
 
 @compiles(LookupAll) # type: ignore[no-untyped-call, misc]
 def _default_lookup_all(element: LookupAll,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
-    col, tokens = list(element.clauses)
+    _, col, _, tokens = list(element.clauses)
     return "(%s @> %s)" % (compiler.process(col, **kw),
                            compiler.process(tokens, **kw))
 
 
+@compiles(LookupAll, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_lookup_all(element: LookupAll,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    place, col, colname, tokens = list(element.clauses)
+    return "(%s IN (SELECT CAST(value as bigint) FROM"\
+           " (SELECT array_intersect_fuzzy(places) as p FROM"\
+           "   (SELECT places FROM reverse_search_name"\
+           "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
+           "     AND column = %s"\
+           "   ORDER BY length(places)) as x) as u,"\
+           " json_each('[' || u.p || ']'))"\
+           " AND array_contains(%s, %s))"\
+             % (compiler.process(place, **kw),
+                compiler.process(tokens, **kw),
+                compiler.process(colname, **kw),
+                compiler.process(col, **kw),
+                compiler.process(tokens, **kw)
+                )
+
+
 
 class LookupAny(LookupType):
     """ Find all entries that contain at least one of the given tokens.
@@ -46,17 +66,28 @@ class LookupAny(LookupType):
     inherit_cache = True
 
     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
-        super().__init__(getattr(table.c, column),
+        super().__init__(table.c.place_id, getattr(table.c, column), column,
                          sa.type_coerce(tokens, IntArray))
 
-
 @compiles(LookupAny) # type: ignore[no-untyped-call, misc]
 def _default_lookup_any(element: LookupAny,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
-    col, tokens = list(element.clauses)
+    _, col, _, tokens = list(element.clauses)
     return "(%s && %s)" % (compiler.process(col, **kw),
                            compiler.process(tokens, **kw))
 
+@compiles(LookupAny, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_lookup_any(element: LookupAny,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    place, _, colname, tokens = list(element.clauses)
+    return "%s IN (SELECT CAST(value as bigint) FROM"\
+           " (SELECT array_union(places) as p FROM reverse_search_name"\
+           "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
+           "     AND column = %s) as u,"\
+           " json_each('[' || u.p || ']'))" % (compiler.process(place, **kw),
+                                               compiler.process(tokens, **kw),
+                                               compiler.process(colname, **kw))
+
 
 
 class Restrict(LookupType):
@@ -76,3 +107,8 @@ def _default_restrict(element: Restrict,
     arg1, arg2 = list(element.clauses)
     return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
                                            compiler.process(arg2, **kw))
+
+@compiles(Restrict, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_restrict(element: Restrict,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_contains(%s)" % compiler.process(element.clauses, **kw)