"""
Specialised connection and cursor functions.
"""
-from typing import List, Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
import contextlib
import logging
import os
""" Query execution that logs the SQL query when debugging is enabled.
"""
if LOG.isEnabledFor(logging.DEBUG):
- LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore[no-untyped-call]
+ LOG.debug(self.mogrify(query, args).decode('utf-8'))
super().execute(query, args)
- def execute_values(self, sql: Query, argslist: List[Any],
- template: Optional[str] = None) -> None:
+ def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
+ template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""
if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
- result = self.fetchone() # type: ignore[no-untyped-call]
+ result = self.fetchone()
assert result is not None
return result[0]
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
- Set `if_exists` to False if a non-existant table should raise
+ Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. If 'cascade' is set
to True then all dependent tables are deleted as well.
"""
return False
if table is not None:
- row = cur.fetchone() # type: ignore[no-untyped-call]
+ row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False
return row[0] == table
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
- Set `if_exists` to False if a non-existant table should raise
+ Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored.
"""
with self.cursor() as cur:
return (int(version_parts[0]), int(version_parts[1]))
+
+ def extension_loaded(self, extension_name: str) -> bool:
+ """ Return True if the extension with the given name is loaded in the database.
+ """
+ with self.cursor() as cur:
+ cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
+ return cur.rowcount > 0
+
+
class ConnectionContext(ContextManager[Connection]):
""" Context manager of the connection that also provides direct access
to the underlying connection.
try:
conn = psycopg2.connect(dsn, connection_factory=Connection)
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
- ctxmgr.connection = cast(Connection, conn)
+ ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err
def get_pg_env(dsn: str,
base_env: Optional[SysEnv] = None) -> Dict[str, str]:
""" Return a copy of `base_env` with the environment variables for
- PostgresSQL set up from the given database connection string.
+ PostgreSQL set up from the given database connection string.
If `base_env` is None, then the OS environment is used as a base
environment.
"""
env = dict(base_env if base_env is not None else os.environ)
- for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
+ for param, value in psycopg2.extensions.parse_dsn(dsn).items():
if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value
else: