git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/core.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_api / core.py
index 6c4c37d7e824b681e91f3de6ace74968cc440813..3cf9e989141b5f118fe8d287805ad754124d42df 100644 (file)
@@ -7,7 +7,8 @@
 """
 Implementation of classes for API access via libraries.
 """
 """
 Implementation of classes for API access via libraries.
 """
-from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple, cast
+from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, \
+                   Union, Tuple, cast
 import asyncio
 import sys
 import contextlib
 import asyncio
 import sys
 import contextlib
@@ -20,17 +21,17 @@ from .errors import UsageError
 from .sql.sqlalchemy_schema import SearchTables
 from .sql.async_core_library import PGCORE_LIB, PGCORE_ERROR
 from .config import Configuration
 from .sql.sqlalchemy_schema import SearchTables
 from .sql.async_core_library import PGCORE_LIB, PGCORE_ERROR
 from .config import Configuration
-from .sql import sqlite_functions, sqlalchemy_functions #pylint: disable=unused-import
+from .sql import sqlite_functions, sqlalchemy_functions  # noqa
 from .connection import SearchConnection
 from .status import get_status, StatusResult
 from .connection import SearchConnection
 from .status import get_status, StatusResult
-from .lookup import get_detailed_place, get_simple_place
+from .lookup import get_places, get_detailed_place
 from .reverse import ReverseGeocoder
 from .search import ForwardGeocoder, Phrase, PhraseType, make_query_analyzer
 from . import types as ntyp
 from .results import DetailedResult, ReverseResult, SearchResults
 
 
 from .reverse import ReverseGeocoder
 from .search import ForwardGeocoder, Phrase, PhraseType, make_query_analyzer
 from . import types as ntyp
 from .results import DetailedResult, ReverseResult, SearchResults
 
 
-class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
+class NominatimAPIAsync:
     """ The main frontend to the Nominatim database implements the
         functions for lookup, forward and reverse geocoding using
         asynchronous functions.
     """ The main frontend to the Nominatim database implements the
         functions for lookup, forward and reverse geocoding using
         asynchronous functions.
@@ -38,8 +39,10 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         This class shares most of the functions with its synchronous
         version. There are some additional functions or parameters,
         which are documented below.
         This class shares most of the functions with its synchronous
         version. There are some additional functions or parameters,
         which are documented below.
+
+        This class should usually be used as a context manager in an 'async with' block.
     """
     """
-    def __init__(self, project_dir: Path,
+    def __init__(self, project_dir: Optional[Union[str, Path]] = None,
                  environ: Optional[Mapping[str, str]] = None,
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
         """ Initiate a new frontend object with asynchronous API functions.
                  environ: Optional[Mapping[str, str]] = None,
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
         """ Initiate a new frontend object with asynchronous API functions.
@@ -58,19 +61,18 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         """
         self.config = Configuration(project_dir, environ)
         self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \
         """
         self.config = Configuration(project_dir, environ)
         self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \
-                             if self.config.QUERY_TIMEOUT else None
+            if self.config.QUERY_TIMEOUT else None
         self.reverse_restrict_to_country_area = self.config.get_bool('SEARCH_WITHIN_COUNTRIES')
         self.server_version = 0
 
         if sys.version_info >= (3, 10):
             self._engine_lock = asyncio.Lock()
         else:
         self.reverse_restrict_to_country_area = self.config.get_bool('SEARCH_WITHIN_COUNTRIES')
         self.server_version = 0
 
         if sys.version_info >= (3, 10):
             self._engine_lock = asyncio.Lock()
         else:
-            self._engine_lock = asyncio.Lock(loop=loop) # pylint: disable=unexpected-keyword-arg
+            self._engine_lock = asyncio.Lock(loop=loop)
         self._engine: Optional[sa_asyncio.AsyncEngine] = None
         self._tables: Optional[SearchTables] = None
         self._property_cache: Dict[str, Any] = {'DB:server_version': 0}
 
         self._engine: Optional[sa_asyncio.AsyncEngine] = None
         self._tables: Optional[SearchTables] = None
         self._property_cache: Dict[str, Any] = {'DB:server_version': 0}
 
-
     async def setup_database(self) -> None:
         """ Set up the SQL engine and connections.
 
     async def setup_database(self) -> None:
         """ Set up the SQL engine and connections.
 
@@ -92,7 +94,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                 extra_args['max_overflow'] = 0
                 extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE')
 
                 extra_args['max_overflow'] = 0
                 extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE')
 
-
             is_sqlite = self.config.DATABASE_DSN.startswith('sqlite:')
 
             if is_sqlite:
             is_sqlite = self.config.DATABASE_DSN.startswith('sqlite:')
 
             if is_sqlite:
@@ -137,26 +138,23 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                     async with engine.begin() as conn:
                         result = await conn.scalar(sa.text('SHOW server_version_num'))
                         server_version = int(result)
                     async with engine.begin() as conn:
                         result = await conn.scalar(sa.text('SHOW server_version_num'))
                         server_version = int(result)
-                        if server_version >= 110000:
-                            await conn.execute(sa.text("SET jit_above_cost TO '-1'"))
-                            await conn.execute(sa.text(
-                                    "SET max_parallel_workers_per_gather TO '0'"))
+                        await conn.execute(sa.text("SET jit_above_cost TO '-1'"))
+                        await conn.execute(sa.text(
+                                "SET max_parallel_workers_per_gather TO '0'"))
                 except (PGCORE_ERROR, sa.exc.OperationalError):
                     server_version = 0
 
                 except (PGCORE_ERROR, sa.exc.OperationalError):
                     server_version = 0
 
-                if server_version >= 110000:
-                    @sa.event.listens_for(engine.sync_engine, "connect")
-                    def _on_connect(dbapi_con: Any, _: Any) -> None:
-                        cursor = dbapi_con.cursor()
-                        cursor.execute("SET jit_above_cost TO '-1'")
-                        cursor.execute("SET max_parallel_workers_per_gather TO '0'")
+                @sa.event.listens_for(engine.sync_engine, "connect")
+                def _on_connect(dbapi_con: Any, _: Any) -> None:
+                    cursor = dbapi_con.cursor()
+                    cursor.execute("SET jit_above_cost TO '-1'")
+                    cursor.execute("SET max_parallel_workers_per_gather TO '0'")
 
             self._property_cache['DB:server_version'] = server_version
 
 
             self._property_cache['DB:server_version'] = server_version
 
-            self._tables = SearchTables(sa.MetaData()) # pylint: disable=no-member
+            self._tables = SearchTables(sa.MetaData())
             self._engine = engine
 
             self._engine = engine
 
-
     async def close(self) -> None:
         """ Close all active connections to the database. The NominatimAPIAsync
             object remains usable after closing. If a new API function is
     async def close(self) -> None:
         """ Close all active connections to the database. The NominatimAPIAsync
             object remains usable after closing. If a new API function is
@@ -165,6 +163,11 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         if self._engine is not None:
             await self._engine.dispose()
 
         if self._engine is not None:
             await self._engine.dispose()
 
+    async def __aenter__(self) -> 'NominatimAPIAsync':
+        return self
+
+    async def __aexit__(self, *_: Any) -> None:
+        await self.close()
 
     @contextlib.asynccontextmanager
     async def begin(self) -> AsyncIterator[SearchConnection]:
 
     @contextlib.asynccontextmanager
     async def begin(self) -> AsyncIterator[SearchConnection]:
@@ -183,7 +186,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         async with self._engine.begin() as conn:
             yield SearchConnection(conn, self._tables, self._property_cache)
 
         async with self._engine.begin() as conn:
             yield SearchConnection(conn, self._tables, self._property_cache)
 
-
     async def status(self) -> StatusResult:
         """ Return the status of the database.
         """
     async def status(self) -> StatusResult:
         """ Return the status of the database.
         """
@@ -196,7 +198,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
 
         return status
 
 
         return status
 
-
     async def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]:
         """ Get detailed information about a place in the database.
 
     async def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]:
         """ Get detailed information about a place in the database.
 
@@ -209,7 +210,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                 await make_query_analyzer(conn)
             return await get_detailed_place(conn, place, details)
 
                 await make_query_analyzer(conn)
             return await get_detailed_place(conn, place, details)
 
-
     async def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
         """ Get simple information about a list of places.
 
     async def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
         """ Get simple information about a list of places.
 
@@ -220,9 +220,7 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
             conn.set_query_timeout(self.query_timeout)
             if details.keywords:
                 await make_query_analyzer(conn)
             conn.set_query_timeout(self.query_timeout)
             if details.keywords:
                 await make_query_analyzer(conn)
-            return SearchResults(filter(None,
-                                        [await get_simple_place(conn, p, details) for p in places]))
-
+            return await get_places(conn, places, details)
 
     async def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]:
         """ Find a place by its coordinates. Also known as reverse geocoding.
 
     async def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]:
         """ Find a place by its coordinates. Also known as reverse geocoding.
@@ -244,7 +242,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                                        self.reverse_restrict_to_country_area)
             return await geocoder.lookup(coord)
 
                                        self.reverse_restrict_to_country_area)
             return await geocoder.lookup(coord)
 
-
     async def search(self, query: str, **params: Any) -> SearchResults:
         """ Find a place by free-text search. Also known as forward geocoding.
         """
     async def search(self, query: str, **params: Any) -> SearchResults:
         """ Find a place by free-text search. Also known as forward geocoding.
         """
@@ -255,13 +252,11 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         async with self.begin() as conn:
             conn.set_query_timeout(self.query_timeout)
             geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params),
         async with self.begin() as conn:
             conn.set_query_timeout(self.query_timeout)
             geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params),
-                                       self.config.get_int('REQUEST_TIMEOUT') \
-                                         if self.config.REQUEST_TIMEOUT else None)
+                                       self.config.get_int('REQUEST_TIMEOUT')
+                                       if self.config.REQUEST_TIMEOUT else None)
             phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')]
             return await geocoder.lookup(phrases)
 
             phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')]
             return await geocoder.lookup(phrases)
 
-
-    # pylint: disable=too-many-arguments,too-many-branches
     async def search_address(self, amenity: Optional[str] = None,
                              street: Optional[str] = None,
                              city: Optional[str] = None,
     async def search_address(self, amenity: Optional[str] = None,
                              street: Optional[str] = None,
                              city: Optional[str] = None,
@@ -315,11 +310,10 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                     details.layers |= ntyp.DataLayer.POI
 
             geocoder = ForwardGeocoder(conn, details,
                     details.layers |= ntyp.DataLayer.POI
 
             geocoder = ForwardGeocoder(conn, details,
-                                       self.config.get_int('REQUEST_TIMEOUT') \
-                                         if self.config.REQUEST_TIMEOUT else None)
+                                       self.config.get_int('REQUEST_TIMEOUT')
+                                       if self.config.REQUEST_TIMEOUT else None)
             return await geocoder.lookup(phrases)
 
             return await geocoder.lookup(phrases)
 
-
     async def search_category(self, categories: List[Tuple[str, str]],
                               near_query: Optional[str] = None,
                               **params: Any) -> SearchResults:
     async def search_category(self, categories: List[Tuple[str, str]],
                               near_query: Optional[str] = None,
                               **params: Any) -> SearchResults:
@@ -341,19 +335,20 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                     await make_query_analyzer(conn)
 
             geocoder = ForwardGeocoder(conn, details,
                     await make_query_analyzer(conn)
 
             geocoder = ForwardGeocoder(conn, details,
-                                       self.config.get_int('REQUEST_TIMEOUT') \
-                                         if self.config.REQUEST_TIMEOUT else None)
+                                       self.config.get_int('REQUEST_TIMEOUT')
+                                       if self.config.REQUEST_TIMEOUT else None)
             return await geocoder.lookup_pois(categories, phrases)
 
 
             return await geocoder.lookup_pois(categories, phrases)
 
 
-
 class NominatimAPI:
     """ This class provides a thin synchronous wrapper around the asynchronous
         Nominatim functions. It creates its own event loop and runs each
         synchronous function call to completion using that loop.
 class NominatimAPI:
     """ This class provides a thin synchronous wrapper around the asynchronous
         Nominatim functions. It creates its own event loop and runs each
         synchronous function call to completion using that loop.
+
+        This class should usually be used as a context manager in a 'with' statement.
     """
 
     """
 
-    def __init__(self, project_dir: Path,
+    def __init__(self, project_dir: Optional[Union[str, Path]] = None,
                  environ: Optional[Mapping[str, str]] = None) -> None:
         """ Initiate a new frontend object with synchronous API functions.
 
                  environ: Optional[Mapping[str, str]] = None) -> None:
         """ Initiate a new frontend object with synchronous API functions.
 
@@ -369,16 +364,21 @@ class NominatimAPI:
         self._loop = asyncio.new_event_loop()
         self._async_api = NominatimAPIAsync(project_dir, environ, loop=self._loop)
 
         self._loop = asyncio.new_event_loop()
         self._async_api = NominatimAPIAsync(project_dir, environ, loop=self._loop)
 
-
     def close(self) -> None:
         """ Close all active connections to the database.
 
             This function also closes the asynchronous worker loop making
             the NominatimAPI object unusable.
         """
     def close(self) -> None:
         """ Close all active connections to the database.
 
             This function also closes the asynchronous worker loop making
             the NominatimAPI object unusable.
         """
-        self._loop.run_until_complete(self._async_api.close())
-        self._loop.close()
+        if not self._loop.is_closed():
+            self._loop.run_until_complete(self._async_api.close())
+            self._loop.close()
+
+    def __enter__(self) -> 'NominatimAPI':
+        return self
 
 
+    def __exit__(self, *_: Any) -> None:
+        self.close()
 
     @property
     def config(self) -> Configuration:
 
     @property
     def config(self) -> Configuration:
@@ -405,7 +405,6 @@ class NominatimAPI:
         """
         return self._loop.run_until_complete(self._async_api.status())
 
         """
         return self._loop.run_until_complete(self._async_api.status())
 
-
     def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]:
         """ Get detailed information about a place in the database.
 
     def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]:
         """ Get detailed information about a place in the database.
 
@@ -488,7 +487,6 @@ class NominatimAPI:
         """
         return self._loop.run_until_complete(self._async_api.details(place, **params))
 
         """
         return self._loop.run_until_complete(self._async_api.details(place, **params))
 
-
     def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
         """ Get simple information about a list of places.
 
     def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
         """ Get simple information about a list of places.
 
@@ -565,7 +563,6 @@ class NominatimAPI:
         """
         return self._loop.run_until_complete(self._async_api.lookup(places, **params))
 
         """
         return self._loop.run_until_complete(self._async_api.lookup(places, **params))
 
-
     def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]:
         """ Find a place by its coordinates. Also known as reverse geocoding.
 
     def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]:
         """ Find a place by its coordinates. Also known as reverse geocoding.
 
@@ -647,7 +644,6 @@ class NominatimAPI:
         """
         return self._loop.run_until_complete(self._async_api.reverse(coord, **params))
 
         """
         return self._loop.run_until_complete(self._async_api.reverse(coord, **params))
 
-
     def search(self, query: str, **params: Any) -> SearchResults:
         """ Find a place by free-text search. Also known as forward geocoding.
 
     def search(self, query: str, **params: Any) -> SearchResults:
         """ Find a place by free-text search. Also known as forward geocoding.
 
@@ -747,8 +743,6 @@ class NominatimAPI:
         return self._loop.run_until_complete(
                    self._async_api.search(query, **params))
 
         return self._loop.run_until_complete(
                    self._async_api.search(query, **params))
 
-
-    # pylint: disable=too-many-arguments
     def search_address(self, amenity: Optional[str] = None,
                        street: Optional[str] = None,
                        city: Optional[str] = None,
     def search_address(self, amenity: Optional[str] = None,
                        street: Optional[str] = None,
                        city: Optional[str] = None,
@@ -866,7 +860,6 @@ class NominatimAPI:
                    self._async_api.search_address(amenity, street, city, county,
                                                   state, country, postalcode, **params))
 
                    self._async_api.search_address(amenity, street, city, county,
                                                   state, country, postalcode, **params))
 
-
     def search_category(self, categories: List[Tuple[str, str]],
                         near_query: Optional[str] = None,
                         **params: Any) -> SearchResults:
     def search_category(self, categories: List[Tuple[str, str]],
                         near_query: Optional[str] = None,
                         **params: Any) -> SearchResults: