"""
Specialised connection and cursor functions.
"""
-from typing import List, Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
import contextlib
import logging
import os
LOG = logging.getLogger()
-class _Cursor(psycopg2.extras.DictCursor):
+class Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised
execution functions.
"""
""" Query execution that logs the SQL query when debugging is enabled.
"""
if LOG.isEnabledFor(logging.DEBUG):
- LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
+ LOG.debug(self.mogrify(query, args).decode('utf-8'))
super().execute(query, args)
- def execute_values(self, sql: Query, argslist: List[Any],
- template: Optional[str] = None) -> None:
+ def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
+ template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""
if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
- result = self.fetchone() # type: ignore
+ result = self.fetchone()
assert result is not None
return result[0]
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
- Set `if_exists` to False if a non-existant table should raise
+ Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. If 'cascade' is set
to True then all dependent tables are deleted as well.
"""
if cascade:
sql += ' CASCADE'
- self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
+ self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
-class _Connection(psycopg2.extensions.connection):
+class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
"""
@overload # type: ignore[override]
- def cursor(self) -> _Cursor:
+ def cursor(self) -> Cursor:
...
@overload
- def cursor(self, name: str) -> _Cursor:
+ def cursor(self, name: str) -> Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
...
- def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore
+ def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)
return False
if table is not None:
- row = cur.fetchone() # type: ignore
+ row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False
return row[0] == table
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
- Set `if_exists` to False if a non-existant table should raise
+ Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored.
"""
with self.cursor() as cur:
return (int(version_parts[0]), int(version_parts[1]))
-class _ConnectionContext(ContextManager[_Connection]):
- connection: _Connection
-def connect(dsn: str) -> _ConnectionContext:
+    def extension_loaded(self, extension_name: str) -> bool:
+        """ Return True if an extension named `extension_name` is registered
+            in the database's `pg_extension` catalog (exact name match).
+        """
+        with self.cursor() as cur:
+            # Parameterized query: the extension name is passed as a bind
+            # value, never interpolated into the SQL text.
+            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
+            return cur.rowcount > 0
+
+
+class ConnectionContext(ContextManager[Connection]):
+    """ Context manager wrapping a `Connection`, with the wrapped connection
+        also exposed through the `connection` attribute.
+
+        In this module, `connect()` produces it by casting a
+        `contextlib.closing()` object to this type, so leaving a `with` block
+        closes the connection, while code that does not use `with` can reach
+        the connection via `.connection`.
+    """
+    # The wrapped connection; attached by connect() after wrapping.
+    connection: Connection
+
+def connect(dsn: str) -> ConnectionContext:
""" Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute
to get the connection.
"""
try:
- conn = psycopg2.connect(dsn, connection_factory=_Connection)
- ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
- ctxmgr.connection = cast(_Connection, conn)
+ conn = psycopg2.connect(dsn, connection_factory=Connection)
+ # contextlib.closing() has no `connection` attribute of its own: the cast
+ # only informs the type checker; the attribute is attached right after it.
+ ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
+ ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err
"""
env = dict(base_env if base_env is not None else os.environ)
- for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
+ for param, value in psycopg2.extensions.parse_dsn(dsn).items():
if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value
else: