"""
Specialised connection and cursor functions.
"""
-from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
-import contextlib
+from typing import Optional, Any, Dict, Tuple
import logging
import os
-import psycopg2
-import psycopg2.extensions
-import psycopg2.extras
-from psycopg2 import sql as pysql
+import psycopg
+import psycopg.types.hstore
+from psycopg import sql as pysql
-from ..typing import SysEnv, Query, T_cursor
+from ..typing import SysEnv
from ..errors import UsageError
LOG = logging.getLogger()
-class Cursor(psycopg2.extras.DictCursor):
- """ A cursor returning dict-like objects and providing specialised
- execution functions.
- """
- # pylint: disable=arguments-renamed,arguments-differ
- def execute(self, query: Query, args: Any = None) -> None:
- """ Query execution that logs the SQL query when debugging is enabled.
- """
- if LOG.isEnabledFor(logging.DEBUG):
- LOG.debug(self.mogrify(query, args).decode('utf-8'))
-
- super().execute(query, args)
-
-
- def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
- template: Optional[Query] = None) -> None:
- """ Wrapper for the psycopg2 convenience function to execute
- SQL for a list of values.
- """
- LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
-
- psycopg2.extras.execute_values(self, sql, argslist, template=template)
+Cursor = psycopg.Cursor[Any]
+Connection = psycopg.Connection[Any]
- def scalar(self, sql: Query, args: Any = None) -> Any:
- """ Execute query that returns a single value. The value is returned.
- If the query yields more than one row, a ValueError is raised.
- """
- self.execute(sql, args)
+def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
+ """ Execute query that returns a single value. The value is returned.
+ If the query does not yield exactly one row, a RuntimeError is raised.
+ """
+ with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
+ cur.execute(sql, args)
- if self.rowcount != 1:
+ if cur.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
- result = self.fetchone()
- assert result is not None
-
- return result[0]
+ result = cur.fetchone()
+ assert result is not None
+ return result[0]
- def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
- """ Drop the table with the given name.
- Set `if_exists` to False if a non-existent table should raise
- an exception instead of just being ignored. If 'cascade' is set
- to True then all dependent tables are deleted as well.
- """
- sql = 'DROP TABLE '
- if if_exists:
- sql += 'IF EXISTS '
- sql += '{}'
- if cascade:
- sql += ' CASCADE'
- self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
+def table_exists(conn: Connection, table: str) -> bool:
+ """ Check that a table with the given name exists in the database.
+ """
+ num = execute_scalar(
+ conn,
+ """SELECT count(*) FROM pg_tables
+ WHERE tablename = %s and schemaname = 'public'""", (table, ))
+ return num == 1 if isinstance(num, int) else False
-class Connection(psycopg2.extensions.connection):
- """ A connection that provides the specialised cursor by default and
- adds convenience functions for administrating the database.
+def table_has_column(conn: Connection, table: str, column: str) -> bool:
+ """ Check if the table 'table' exists and has a column with name 'column'.
"""
- @overload # type: ignore[override]
- def cursor(self) -> Cursor:
- ...
-
- @overload
- def cursor(self, name: str) -> Cursor:
- ...
-
- @overload
- def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
- ...
-
- def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
- """ Return a new cursor. By default the specialised cursor is returned.
- """
- return super().cursor(cursor_factory=cursor_factory, **kwargs)
-
-
- def table_exists(self, table: str) -> bool:
- """ Check that a table with the given name exists in the database.
- """
- with self.cursor() as cur:
- num = cur.scalar("""SELECT count(*) FROM pg_tables
- WHERE tablename = %s and schemaname = 'public'""", (table, ))
- return num == 1 if isinstance(num, int) else False
-
-
- def table_has_column(self, table: str, column: str) -> bool:
- """ Check if the table 'table' exists and has a column with name 'column'.
- """
- with self.cursor() as cur:
- has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
- WHERE table_name = %s
- and column_name = %s""",
- (table, column))
- return has_column > 0 if isinstance(has_column, int) else False
-
-
- def index_exists(self, index: str, table: Optional[str] = None) -> bool:
- """ Check that an index with the given name exists in the database.
- If table is not None then the index must relate to the given
- table.
- """
- with self.cursor() as cur:
- cur.execute("""SELECT tablename FROM pg_indexes
- WHERE indexname = %s and schemaname = 'public'""", (index, ))
- if cur.rowcount == 0:
- return False
+ has_column = execute_scalar(conn,
+ """SELECT count(*) FROM information_schema.columns
+ WHERE table_name = %s and column_name = %s""",
+ (table, column))
+ return has_column > 0 if isinstance(has_column, int) else False
- if table is not None:
- row = cur.fetchone()
- if row is None or not isinstance(row[0], str):
- return False
- return row[0] == table
- return True
+def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
+ """ Check that an index with the given name exists in the database.
+ If table is not None then the index must relate to the given
+ table.
+ """
+ with conn.cursor() as cur:
+ cur.execute("""SELECT tablename FROM pg_indexes
+ WHERE indexname = %s and schemaname = 'public'""", (index, ))
+ if cur.rowcount == 0:
+ return False
+
+ if table is not None:
+ row = cur.fetchone()
+ if row is None or not isinstance(row[0], str):
+ return False
+ return row[0] == table
+ return True
- def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
- """ Drop the table with the given name.
- Set `if_exists` to False if a non-existent table should raise
- an exception instead of just being ignored.
- """
- with self.cursor() as cur:
- cur.drop_table(name, if_exists, cascade)
- self.commit()
+def drop_tables(conn: Connection, *names: str,
+ if_exists: bool = True, cascade: bool = False) -> None:
+ """ Drop one or more tables with the given names.
+ Set `if_exists` to False if a non-existent table should raise
+ an exception instead of just being ignored. `cascade` will cause
+ dependent objects to be dropped as well.
+ The caller needs to take care of committing the change.
+ """
+ sql = pysql.SQL('DROP TABLE%s{}%s' % (
+ ' IF EXISTS ' if if_exists else ' ',
+ ' CASCADE' if cascade else ''))
- def server_version_tuple(self) -> Tuple[int, int]:
- """ Return the server version as a tuple of (major, minor).
- Converts correctly for pre-10 and post-10 PostgreSQL versions.
- """
- version = self.server_version
- if version < 100000:
- return (int(version / 10000), int((version % 10000) / 100))
+ with conn.cursor() as cur:
+ for name in names:
+ cur.execute(sql.format(pysql.Identifier(name)))
- return (int(version / 10000), version % 10000)
+def server_version_tuple(conn: Connection) -> Tuple[int, int]:
+ """ Return the server version as a tuple of (major, minor).
+ Converts correctly for pre-10 and post-10 PostgreSQL versions.
+ """
+ version = conn.info.server_version
+ if version < 100000:
+ return (int(version / 10000), int((version % 10000) / 100))
- def postgis_version_tuple(self) -> Tuple[int, int]:
- """ Return the postgis version installed in the database as a
- tuple of (major, minor). Assumes that the PostGIS extension
- has been installed already.
- """
- with self.cursor() as cur:
- version = cur.scalar('SELECT postgis_lib_version()')
+ return (int(version / 10000), version % 10000)
- version_parts = version.split('.')
- if len(version_parts) < 2:
- raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
- return (int(version_parts[0]), int(version_parts[1]))
+def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
+ """ Return the postgis version installed in the database as a
+ tuple of (major, minor). Assumes that the PostGIS extension
+ has been installed already.
+ """
+ version = execute_scalar(conn, 'SELECT postgis_lib_version()')
+ version_parts = version.split('.')
+ if len(version_parts) < 2:
+ raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
- def extension_loaded(self, extension_name: str) -> bool:
- """ Return True if the hstore extension is loaded in the database.
- """
- with self.cursor() as cur:
- cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
- return cur.rowcount > 0
+ return (int(version_parts[0]), int(version_parts[1]))
-class ConnectionContext(ContextManager[Connection]):
- """ Context manager of the connection that also provides direct access
- to the underlying connection.
+def register_hstore(conn: Connection) -> None:
+ """ Register the hstore type with psycopg for the connection.
"""
- connection: Connection
+ info = psycopg.types.TypeInfo.fetch(conn, "hstore")
+ if info is None:
+ raise RuntimeError('Hstore extension is requested but not installed.')
+ psycopg.types.hstore.register_hstore(info, conn)
-def connect(dsn: str) -> ConnectionContext:
+
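+def connect(dsn: str, **kwargs: Any) -> Connection:
- """ Open a connection to the database using the specialised connection
- factory. The returned object may be used in conjunction with 'with'.
- When used outside a context manager, use the `connection` attribute
- to get the connection.
+ """ Open a connection to the database and return it. Any additional
+ keyword arguments are handed on to psycopg.connect(). The returned
+ connection may be used in conjunction with 'with'. A UsageError is
+ raised when psycopg reports an operational error while connecting.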
"""
try:
- conn = psycopg2.connect(dsn, connection_factory=Connection)
- ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
- ctxmgr.connection = conn
- return ctxmgr
- except psycopg2.OperationalError as err:
+ return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
+ except psycopg.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err
"""
env = dict(base_env if base_env is not None else os.environ)
- for param, value in psycopg2.extensions.parse_dsn(dsn).items():
+ for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
if param in _PG_CONNECTION_STRINGS:
- env[_PG_CONNECTION_STRINGS[param]] = value
+ env[_PG_CONNECTION_STRINGS[param]] = str(value)
else:
LOG.error("Unknown connection parameter '%s' ignored.", param)
return env
+
+
+async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
+ """ Open a connection to the database and run a single query
+ asynchronously.
+ """
+ async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+ await aconn.execute(query)