]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/config.py
Merge pull request #3525 from lonvia/project-dir-less-library
[nominatim.git] / src / nominatim_db / config.py
index c4264f0d68ef17c0a182a1063c0a970862cf5aec..2cab02377593a6ce71286199105faf49e12732ff 100644 (file)
@@ -7,7 +7,7 @@
 """
 Nominatim configuration accessor.
 """
-from typing import Dict, Any, List, Mapping, Optional
+from typing import Union, Dict, Any, List, Mapping, Optional
 import importlib.util
 import logging
 import os
@@ -18,10 +18,7 @@ import yaml
 
 from dotenv import dotenv_values
 
-try:
-    from psycopg2.extensions import parse_dsn
-except ModuleNotFoundError:
-    from psycopg.conninfo import conninfo_to_dict as parse_dsn # type: ignore[assignment]
+from psycopg.conninfo import conninfo_to_dict
 
 from .typing import StrPath
 from .errors import UsageError
@@ -62,15 +59,17 @@ class Configuration:
         other than string.
     """
 
-    def __init__(self, project_dir: Optional[Path],
+    def __init__(self, project_dir: Optional[Union[Path, str]],
                  environ: Optional[Mapping[str, str]] = None) -> None:
         self.environ = environ or os.environ
-        self.project_dir = project_dir
         self.config_dir = paths.CONFIG_DIR
         self._config = dotenv_values(str(self.config_dir / 'env.defaults'))
-        if self.project_dir is not None and (self.project_dir / '.env').is_file():
-            self.project_dir = self.project_dir.resolve()
-            self._config.update(dotenv_values(str(self.project_dir / '.env')))
+        if project_dir is not None:
+            self.project_dir: Optional[Path] = Path(project_dir).resolve()
+            if (self.project_dir / '.env').is_file():
+                self._config.update(dotenv_values(str(self.project_dir / '.env')))
+        else:
+            self.project_dir = None
 
         class _LibDirs:
             module: Path
@@ -198,7 +197,7 @@ class Configuration:
         return dsn
 
 
-    def get_database_params(self) -> Mapping[str, str]:
+    def get_database_params(self) -> Mapping[str, Union[str, int, None]]:
         """ Get the configured parameters for the database connection
             as a mapping.
         """
@@ -207,7 +206,7 @@ class Configuration:
         if dsn.startswith('pgsql:'):
             return dict((p.split('=', 1) for p in dsn[6:].split(';')))
 
-        return parse_dsn(dsn)
+        return conninfo_to_dict(dsn)
 
 
     def get_import_style_file(self) -> Path: