git.openstreetmap.org Git - nominatim.git/commitdiff
Merge remote-tracking branch 'upstream/master'
author Sarah Hoffmann <lonvia@denofr.de>
Fri, 25 Aug 2023 08:01:06 +0000 (10:01 +0200)
committer Sarah Hoffmann <lonvia@denofr.de>
Fri, 25 Aug 2023 08:01:06 +0000 (10:01 +0200)
nominatim/api/connection.py
nominatim/api/core.py
nominatim/api/search/db_search_builder.py
nominatim/api/search/geocoder.py
nominatim/server/falcon/server.py
nominatim/server/starlette/server.py
settings/env.defaults

nominatim/api/connection.py
index bf2173144d72fa7deee39daa010f1e75ef5293dc..405213e97659d32fb9ff9d56c2478219690af6a4 100644
@@ -9,6 +9,7 @@ Extended SQLAlchemy connection class that also includes access to the schema.
 """
 from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \
                    Awaitable, Callable, TypeVar
+import asyncio
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
@@ -34,6 +35,14 @@ class SearchConnection:
         self.t = tables # pylint: disable=invalid-name
         self._property_cache = properties
         self._classtables: Optional[Set[str]] = None
+        self.query_timeout: Optional[int] = None
+
+
+    def set_query_timeout(self, timeout: Optional[int]) -> None:
+        """ Set the timeout after which a query over this connection
+            is cancelled.
+        """
+        self.query_timeout = timeout
 
 
     async def scalar(self, sql: sa.sql.base.Executable,
@@ -42,7 +51,7 @@ class SearchConnection:
         """ Execute a 'scalar()' query on the connection.
         """
         log().sql(self.connection, sql, params)
-        return await self.connection.scalar(sql, params)
+        return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout)
 
 
     async def execute(self, sql: 'sa.Executable',
@@ -51,7 +60,7 @@ class SearchConnection:
         """ Execute a 'execute()' query on the connection.
         """
         log().sql(self.connection, sql, params)
-        return await self.connection.execute(sql, params)
+        return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout)
 
 
     async def get_property(self, name: str, cached: bool = True) -> str:
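
The two wrappers above are the whole enforcement mechanism for NOMINATIM_QUERY_TIMEOUT: asyncio.wait_for cancels the awaited statement and raises asyncio.TimeoutError (an alias of the built-in TimeoutError from Python 3.11 on) once the limit is reached, and a timeout of None waits indefinitely, which is why an unset setting disables the feature. A minimal, self-contained sketch of the pattern; slow_query is a stand-in, not part of Nominatim:

    import asyncio

    async def slow_query() -> str:
        # Stand-in for a database statement that takes too long.
        await asyncio.sleep(5)
        return "rows"

    async def main() -> None:
        try:
            # timeout=None would wait forever, mirroring an empty
            # NOMINATIM_QUERY_TIMEOUT.
            print(await asyncio.wait_for(slow_query(), timeout=1))
        except asyncio.TimeoutError:
            print("query cancelled after 1 second")

    asyncio.run(main())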
nominatim/api/core.py
index 1690b9f5e241576dcb35982731af28863ebd37e7..0f1dd7153af8f582410f398ab984d57ff8a18fb2 100644
@@ -36,6 +36,8 @@ class NominatimAPIAsync:
                  environ: Optional[Mapping[str, str]] = None,
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
         self.config = Configuration(project_dir, environ)
+        self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \
+                             if self.config.QUERY_TIMEOUT else None
         self.server_version = 0
 
         if sys.version_info >= (3, 10):
@@ -128,6 +130,7 @@ class NominatimAPIAsync:
         """
         try:
             async with self.begin() as conn:
+                conn.set_query_timeout(self.query_timeout)
                 status = await get_status(conn)
         except (PGCORE_ERROR, sa.exc.OperationalError):
             return StatusResult(700, 'Database connection failed')
@@ -142,6 +145,7 @@ class NominatimAPIAsync:
         """
         details = ntyp.LookupDetails.from_kwargs(params)
         async with self.begin() as conn:
+            conn.set_query_timeout(self.query_timeout)
             if details.keywords:
                 await make_query_analyzer(conn)
             return await get_detailed_place(conn, place, details)
@@ -154,6 +158,7 @@ class NominatimAPIAsync:
         """
         details = ntyp.LookupDetails.from_kwargs(params)
         async with self.begin() as conn:
+            conn.set_query_timeout(self.query_timeout)
             if details.keywords:
                 await make_query_analyzer(conn)
             return SearchResults(filter(None,
@@ -173,6 +178,7 @@ class NominatimAPIAsync:
 
         details = ntyp.ReverseDetails.from_kwargs(params)
         async with self.begin() as conn:
+            conn.set_query_timeout(self.query_timeout)
             if details.keywords:
                 await make_query_analyzer(conn)
             geocoder = ReverseGeocoder(conn, details)
@@ -187,7 +193,10 @@ class NominatimAPIAsync:
             raise UsageError('Nothing to search for.')
 
         async with self.begin() as conn:
-            geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params))
+            conn.set_query_timeout(self.query_timeout)
+            geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params),
+                                       self.config.get_int('REQUEST_TIMEOUT') \
+                                         if self.config.REQUEST_TIMEOUT else None)
             phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')]
             return await geocoder.lookup(phrases)
 
@@ -204,6 +213,7 @@ class NominatimAPIAsync:
         """ Find an address using structured search.
         """
         async with self.begin() as conn:
+            conn.set_query_timeout(self.query_timeout)
             details = ntyp.SearchDetails.from_kwargs(params)
 
             phrases: List[Phrase] = []
@@ -244,7 +254,9 @@ class NominatimAPIAsync:
                 if amenity:
                     details.layers |= ntyp.DataLayer.POI
 
-            geocoder = ForwardGeocoder(conn, details)
+            geocoder = ForwardGeocoder(conn, details,
+                                       self.config.get_int('REQUEST_TIMEOUT') \
+                                         if self.config.REQUEST_TIMEOUT else None)
             return await geocoder.lookup(phrases)
 
 
@@ -260,6 +272,7 @@ class NominatimAPIAsync:
 
         details = ntyp.SearchDetails.from_kwargs(params)
         async with self.begin() as conn:
+            conn.set_query_timeout(self.query_timeout)
             if near_query:
                 phrases = [Phrase(PhraseType.NONE, p) for p in near_query.split(',')]
             else:
@@ -267,7 +280,9 @@ class NominatimAPIAsync:
                 if details.keywords:
                     await make_query_analyzer(conn)
 
-            geocoder = ForwardGeocoder(conn, details)
+            geocoder = ForwardGeocoder(conn, details,
+                                       self.config.get_int('REQUEST_TIMEOUT') \
+                                         if self.config.REQUEST_TIMEOUT else None)
             return await geocoder.lookup_pois(categories, phrases)
 
 
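From the caller's point of view nothing changes: both timeouts are read from the configuration when NominatimAPIAsync is constructed, applied per connection and per forward search, and an empty setting leaves the corresponding limit off. A hedged usage sketch; the project directory and query are placeholders:

    import asyncio
    from pathlib import Path

    from nominatim.api.core import NominatimAPIAsync

    async def main() -> None:
        # NOMINATIM_QUERY_TIMEOUT and NOMINATIM_REQUEST_TIMEOUT are picked
        # up from the project configuration / environment on construction.
        api = NominatimAPIAsync(Path('/path/to/project'))
        try:
            results = await api.search('Birmingham')
            print(len(results))   # SearchResults is list-like
        finally:
            await api.close()

    asyncio.run(main())
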
nominatim/api/search/db_search_builder.py
index c9e48b0f3784f1bb7f6cd6cc9934b25c757c7b33..03e78d45ed3d36ddd85048aef2aa9bfd437185c5 100644
@@ -214,14 +214,14 @@ class SearchBuilder:
 
         # Partial term too frequent. Try looking up by rare full names first.
         name_fulls = self.query.get_tokens(name, TokenType.WORD)
-        fulls_count = sum(t.count for t in name_fulls) / (2**len(addr_partials))
+        fulls_count = sum(t.count for t in name_fulls)
         # At this point drop unindexed partials from the address.
         # This might yield wrong results, nothing we can do about that.
         if not partials_indexed:
             addr_tokens = [t.token for t in addr_partials if t.is_indexed]
             penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
         # Any of the full names applies with all of the partials from the address
-        yield penalty, fulls_count,\
+        yield penalty, fulls_count / (2**len(addr_partials)),\
               dbf.lookup_by_any_name([t.token for t in name_fulls], addr_tokens,
                                      'restrict' if fulls_count < 10000 else 'lookup_all')
 
nominatim/api/search/geocoder.py
index 564e3d8dadc1df4181c5f3a2a5c4bfb86ade0c77..f88bffbd367bb3b78042375565dc22e820816e6f 100644
@@ -9,6 +9,7 @@ Public interface to the search code.
 """
 from typing import List, Any, Optional, Iterator, Tuple
 import itertools
+import datetime as dt
 
 from nominatim.api.connection import SearchConnection
 from nominatim.api.types import SearchDetails
@@ -24,9 +25,11 @@ class ForwardGeocoder:
     """ Main class responsible for place search.
     """
 
-    def __init__(self, conn: SearchConnection, params: SearchDetails) -> None:
+    def __init__(self, conn: SearchConnection,
+                 params: SearchDetails, timeout: Optional[int]) -> None:
         self.conn = conn
         self.params = params
+        self.timeout = dt.timedelta(seconds=timeout or 1000000)
         self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
 
 
@@ -71,6 +74,7 @@ class ForwardGeocoder:
         """
         log().section('Execute database searches')
         results = SearchResults()
+        end_time = dt.datetime.now() + self.timeout
 
         num_results = 0
         min_ranking = 1000.0
@@ -85,6 +89,8 @@ class ForwardGeocoder:
             log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:]))
             num_results = len(results)
             prev_penalty = search.penalty
+            if dt.datetime.now() >= end_time:
+                break
 
         if results:
             min_ranking = min(r.ranking for r in results)
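
Unlike the per-statement query timeout, the request timeout is a soft deadline: the search that is currently running finishes, but no further candidate search is started once the deadline has passed, so whatever has been collected up to that point is returned. A stripped-down sketch of the loop pattern; the search callables are stand-ins:

    import datetime as dt
    from typing import Callable, List, Optional

    def run_searches(searches: List[Callable[[], List[str]]],
                     timeout: Optional[int] = None) -> List[str]:
        # None mirrors an unset NOMINATIM_REQUEST_TIMEOUT: practically no limit.
        end_time = dt.datetime.now() + dt.timedelta(seconds=timeout or 1000000)
        results: List[str] = []
        for search in searches:
            results.extend(search())          # the running search still completes
            if dt.datetime.now() >= end_time:
                break                         # return what has been found so far
        return results

    print(run_searches([lambda: ['a'], lambda: ['b']], timeout=60))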
nominatim/server/falcon/server.py
index e551e54256f531ddc1e101caa2ada639b09b05e4..f1030f5c82205c85b798e418f3f740343ce04951 100644
@@ -37,6 +37,17 @@ async def nominatim_error_handler(req: Request, resp: Response, #pylint: disable
     resp.content_type = exception.content_type
 
 
+async def timeout_error_handler(req: Request, resp: Response, #pylint: disable=unused-argument
+                                exception: TimeoutError, #pylint: disable=unused-argument
+                                _: Any) -> None:
+    """ Special error handler that passes message and content type as
+        per exception info.
+    """
+    resp.status = 503
+    resp.text = "Query took too long to process."
+    resp.content_type = 'text/plain; charset=utf-8'
+
+
 class ParamWrapper(api_impl.ASGIAdaptor):
     """ Adaptor class for server glue to Falcon framework.
     """
@@ -139,6 +150,7 @@ def get_application(project_dir: Path,
     app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'),
               middleware=middleware)
     app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
+    app.add_error_handler(TimeoutError, timeout_error_handler)
 
     legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
     for name, func in api_impl.ROUTES:
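
Once registered, any TimeoutError that escapes a request handler (on Python 3.11+ that includes the timeout raised by asyncio.wait_for in the connection layer) is turned into a plain-text 503 instead of a generic server error. A minimal standalone Falcon ASGI sketch of the same wiring; the /slow resource is purely illustrative:

    from typing import Any
    import asyncio

    from falcon.asgi import App, Request, Response

    class SlowResource:
        async def on_get(self, req: Request, resp: Response) -> None:
            # Simulate a database statement exceeding its timeout.
            await asyncio.wait_for(asyncio.sleep(10), timeout=1)

    async def timeout_error_handler(req: Request, resp: Response,
                                    exception: TimeoutError, _: Any) -> None:
        resp.status = 503
        resp.text = "Query took too long to process."
        resp.content_type = 'text/plain; charset=utf-8'

    app = App()
    app.add_route('/slow', SlowResource())
    app.add_error_handler(TimeoutError, timeout_error_handler)
    # Run with e.g.: uvicorn module_name:app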
nominatim/server/starlette/server.py
index 5567ac9c9b9e9986bafdfd4fdb2855c66e7e220b..19a9943c9ecce33d8184bcabe1a572bb1fe1e1b7 100644
@@ -7,14 +7,14 @@
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Optional, Mapping, Callable, cast, Coroutine
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
 from pathlib import Path
 import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
-from starlette.responses import Response
+from starlette.responses import Response, PlainTextResponse
 from starlette.requests import Request
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@@ -110,6 +110,13 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
         return response
 
 
+async def timeout_error(request: Request, #pylint: disable=unused-argument
+                        _: Exception) -> Response:
+    """ Error handler for query timeouts.
+    """
+    return PlainTextResponse("Query took too long to process.", status_code=503)
+
+
 def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None,
                     debug: bool = True) -> Starlette:
@@ -136,10 +143,15 @@ def get_application(project_dir: Path,
     if log_file:
         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
 
+    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
+        TimeoutError: timeout_error
+    }
+
     async def _shutdown() -> None:
         await app.state.API.close()
 
     app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    exception_handlers=exceptions,
                     on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
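
Starlette accepts the equivalent mapping directly through the exception_handlers argument, so a TimeoutError raised anywhere below a route ends up as the same plain-text 503 response. A condensed standalone sketch; the /slow route is purely illustrative:

    import asyncio

    from starlette.applications import Starlette
    from starlette.requests import Request
    from starlette.responses import PlainTextResponse, Response
    from starlette.routing import Route

    async def slow(request: Request) -> Response:
        # Raises the built-in TimeoutError on Python 3.11+.
        await asyncio.wait_for(asyncio.sleep(10), timeout=1)
        return PlainTextResponse("done")

    async def timeout_error(request: Request, _: Exception) -> Response:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    app = Starlette(routes=[Route('/slow', slow)],
                    exception_handlers={TimeoutError: timeout_error})
    # Run with e.g.: uvicorn module_name:app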
settings/env.defaults
index c4739e786cdd95aa145d0c0172e949b11bb3433d..ff0a7648741021377ce3009ad3eb48882ccb9940 100644
@@ -214,6 +214,16 @@ NOMINATIM_SERVE_LEGACY_URLS=yes
 # of connections _per worker_.
 NOMINATIM_API_POOL_SIZE=10
 
+# Timeout in seconds after which a single query to the database is cancelled.
+# The user receives a 503 response when a query times out.
+# When empty, timeouts are disabled.
+NOMINATIM_QUERY_TIMEOUT=10
+
+# Maximum time in seconds a single request is allowed to take. When the timeout
+# is exceeded, the available results are returned.
+# When empty, timeouts are disabled.
+NOMINATIM_REQUEST_TIMEOUT=60
+
 # Search elements just within countries
 # If, despite not finding a point within the static grid of countries, it
 # finds a geometry of a region, do not return the geometry. Return "Unable