From 26f30bff283e63597e0f349242ec83ae63844725 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 5 Jul 2022 10:46:55 +0200 Subject: [PATCH] add type annotation to DB utils As a cursor is needed as type, make this a public type. --- nominatim/db/connection.py | 8 ++++---- nominatim/db/utils.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py index 21c4f696..5a1b4695 100644 --- a/nominatim/db/connection.py +++ b/nominatim/db/connection.py @@ -22,7 +22,7 @@ from nominatim.errors import UsageError LOG = logging.getLogger() -class _Cursor(psycopg2.extras.DictCursor): +class Cursor(psycopg2.extras.DictCursor): """ A cursor returning dict-like objects and providing specialised execution functions. """ @@ -82,18 +82,18 @@ class Connection(psycopg2.extensions.connection): adds convenience functions for administrating the database. """ @overload # type: ignore[override] - def cursor(self) -> _Cursor: + def cursor(self) -> Cursor: ... @overload - def cursor(self, name: str) -> _Cursor: + def cursor(self, name: str) -> Cursor: ... @overload def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor: ... - def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore + def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore """ Return a new cursor. By default the specialised cursor is returned. """ return super().cursor(cursor_factory=cursor_factory, **kwargs) diff --git a/nominatim/db/utils.py b/nominatim/db/utils.py index 461cb662..9a7b4f16 100644 --- a/nominatim/db/utils.py +++ b/nominatim/db/utils.py @@ -7,14 +7,14 @@ """ Helper functions for handling DB accesses. """ -from typing import IO, Optional, Union +from typing import IO, Optional, Union, Any, Iterable import subprocess import logging import gzip import io from pathlib import Path -from nominatim.db.connection import get_pg_env +from nominatim.db.connection import get_pg_env, Cursor from nominatim.errors import UsageError LOG = logging.getLogger() @@ -84,20 +84,20 @@ class CopyBuffer: """ Data collector for the copy_from command. """ - def __init__(self): + def __init__(self) -> None: self.buffer = io.StringIO() - def __enter__(self): + def __enter__(self) -> 'CopyBuffer': return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.buffer is not None: self.buffer.close() - def add(self, *data): + def add(self, *data: Any) -> None: """ Add another row of data to the copy buffer. """ first = True @@ -113,9 +113,9 @@ class CopyBuffer: self.buffer.write('\n') - def copy_out(self, cur, table, columns=None): + def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None: """ Copy all collected data into the given table. """ if self.buffer.tell() > 0: self.buffer.seek(0) - cur.copy_from(self.buffer, table, columns=columns) + cur.copy_from(self.buffer, table, columns=columns) # type: ignore[no-untyped-call] -- 2.39.5