]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/connection.py
Merge pull request #2476 from lonvia/harmonize-configuration-file-settings
[nominatim.git] / nominatim / db / connection.py
index 68e988f6a4c57023e23549c86946f1df28a2ee9a..1319ac16ea21c6e3d9424ea3bd205dd369768b4d 100644 (file)
@@ -8,8 +8,9 @@ import os
 import psycopg2
 import psycopg2.extensions
 import psycopg2.extras
 import psycopg2
 import psycopg2.extensions
 import psycopg2.extras
+from psycopg2 import sql as pysql
 
 
-from ..errors import UsageError
+from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
@@ -25,6 +26,16 @@ class _Cursor(psycopg2.extras.DictCursor):
 
         super().execute(query, args)
 
 
         super().execute(query, args)
 
+
+    def execute_values(self, sql, argslist, template=None):
+        """ Wrapper for the psycopg2 convenience function to execute
+            SQL for a list of values.
+        """
+        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
+
+        psycopg2.extras.execute_values(self, sql, argslist, template=template)
+
+
     def scalar(self, sql, args=None):
         """ Execute query that returns a single value. The value is returned.
             If the query yields more than one row, a ValueError is raised.
     def scalar(self, sql, args=None):
         """ Execute query that returns a single value. The value is returned.
             If the query yields more than one row, a ValueError is raised.
@@ -37,6 +48,22 @@ class _Cursor(psycopg2.extras.DictCursor):
         return self.fetchone()[0]
 
 
         return self.fetchone()[0]
 
 
+    def drop_table(self, name, if_exists=True, cascade=False):
+        """ Drop the table with the given name.
+            Set `if_exists` to False if a non-existant table should raise
+            an exception instead of just being ignored. If 'cascade' is set
+            to True then all dependent tables are deleted as well.
+        """
+        sql = 'DROP TABLE '
+        if if_exists:
+            sql += 'IF EXISTS '
+        sql += '{}'
+        if cascade:
+            sql += ' CASCADE'
+
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
+
+
 class _Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
 class _Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
@@ -75,15 +102,37 @@ class _Connection(psycopg2.extensions.connection):
         return True
 
 
         return True
 
 
+    def drop_table(self, name, if_exists=True, cascade=False):
+        """ Drop the table with the given name.
+            Set `if_exists` to False if a non-existant table should raise
+            an exception instead of just being ignored.
+        """
+        with self.cursor() as cur:
+            cur.drop_table(name, if_exists, cascade)
+        self.commit()
+
+
     def server_version_tuple(self):
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
     def server_version_tuple(self):
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
-            return (version / 10000, (version % 10000) / 100)
+            return (int(version / 10000), (version % 10000) / 100)
+
+        return (int(version / 10000), version % 10000)
+
+
+    def postgis_version_tuple(self):
+        """ Return the postgis version installed in the database as a
+            tuple of (major, minor). Assumes that the PostGIS extension
+            has been installed already.
+        """
+        with self.cursor() as cur:
+            version = cur.scalar('SELECT postgis_lib_version()')
+
+        return tuple((int(x) for x in version.split('.')[:2]))
 
 
-        return (version / 10000, version % 10000)
 
 def connect(dsn):
     """ Open a connection to the database using the specialised connection
 
 def connect(dsn):
     """ Open a connection to the database using the specialised connection
@@ -123,7 +172,7 @@ _PG_CONNECTION_STRINGS = {
     'sslcrl': 'PGSSLCRL',
     'requirepeer': 'PGREQUIREPEER',
     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
     'sslcrl': 'PGSSLCRL',
     'requirepeer': 'PGREQUIREPEER',
     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
-    'ssl_min_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
+    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
     'gssencmode': 'PGGSSENCMODE',
     'krbsrvname': 'PGKRBSRVNAME',
     'gsslib': 'PGGSSLIB',
     'gssencmode': 'PGGSSENCMODE',
     'krbsrvname': 'PGKRBSRVNAME',
     'gsslib': 'PGGSSLIB',
@@ -138,7 +187,7 @@ def get_pg_env(dsn, base_env=None):
         If `base_env` is None, then the OS environment is used as a base
         environment.
     """
         If `base_env` is None, then the OS environment is used as a base
         environment.
     """
-    env = base_env if base_env is not None else os.environ
+    env = dict(base_env if base_env is not None else os.environ)
 
     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
         if param in _PG_CONNECTION_STRINGS:
 
     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
         if param in _PG_CONNECTION_STRINGS: