X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/4da4cbfe27a576ae011430b2de205c74435e241b..2d115ea4128a869c335d205056174b14febc7bbd:/src/nominatim_db/config.py?ds=sidebyside diff --git a/src/nominatim_db/config.py b/src/nominatim_db/config.py index c4264f0d..b220b5c7 100644 --- a/src/nominatim_db/config.py +++ b/src/nominatim_db/config.py @@ -7,7 +7,7 @@ """ Nominatim configuration accessor. """ -from typing import Dict, Any, List, Mapping, Optional +from typing import Union, Dict, Any, List, Mapping, Optional import importlib.util import logging import os @@ -18,10 +18,7 @@ import yaml from dotenv import dotenv_values -try: - from psycopg2.extensions import parse_dsn -except ModuleNotFoundError: - from psycopg.conninfo import conninfo_to_dict as parse_dsn # type: ignore[assignment] +from psycopg.conninfo import conninfo_to_dict from .typing import StrPath from .errors import UsageError @@ -62,20 +59,20 @@ class Configuration: other than string. """ - def __init__(self, project_dir: Optional[Path], + def __init__(self, project_dir: Optional[Union[Path, str]], environ: Optional[Mapping[str, str]] = None) -> None: - self.environ = environ or os.environ - self.project_dir = project_dir + self.environ = os.environ if environ is None else environ self.config_dir = paths.CONFIG_DIR self._config = dotenv_values(str(self.config_dir / 'env.defaults')) - if self.project_dir is not None and (self.project_dir / '.env').is_file(): - self.project_dir = self.project_dir.resolve() - self._config.update(dotenv_values(str(self.project_dir / '.env'))) + if project_dir is not None: + self.project_dir: Optional[Path] = Path(project_dir).resolve() + if (self.project_dir / '.env').is_file(): + self._config.update(dotenv_values(str(self.project_dir / '.env'))) + else: + self.project_dir = None class _LibDirs: - module: Path osm2pgsql: Path - php = paths.PHPLIB_DIR sql = paths.SQLLIB_DIR data = paths.DATA_DIR @@ -198,7 +195,7 @@ class Configuration: return dsn - def get_database_params(self) -> Mapping[str, str]: + def get_database_params(self) -> Mapping[str, Union[str, int, None]]: """ Get the configured parameters for the database connection as a mapping. """ @@ -207,7 +204,7 @@ class Configuration: if dsn.startswith('pgsql:'): return dict((p.split('=', 1) for p in dsn[6:].split(';'))) - return parse_dsn(dsn) + return conninfo_to_dict(dsn) def get_import_style_file(self) -> Path: