]> git.openstreetmap.org Git - nominatim.git/commitdiff
add type annotation to DB utils
authorSarah Hoffmann <lonvia@denofr.de>
Tue, 5 Jul 2022 08:46:55 +0000 (10:46 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 18 Jul 2022 07:47:57 +0000 (09:47 +0200)
As a cursor is needed as type, make this a public type.

nominatim/db/connection.py
nominatim/db/utils.py

index 21c4f696f684675957c6348b3305a204996f330c..5a1b46959d457b6ce3f9dd7e4d3e6b00590f9434 100644 (file)
@@ -22,7 +22,7 @@ from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-class _Cursor(psycopg2.extras.DictCursor):
+class Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
@@ -82,18 +82,18 @@ class Connection(psycopg2.extensions.connection):
         adds convenience functions for administrating the database.
     """
     @overload # type: ignore[override]
         adds convenience functions for administrating the database.
     """
     @overload # type: ignore[override]
-    def cursor(self) -> _Cursor:
+    def cursor(self) -> Cursor:
         ...
 
     @overload
         ...
 
     @overload
-    def cursor(self, name: str) -> _Cursor:
+    def cursor(self, name: str) -> Cursor:
         ...
 
     @overload
     def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
         ...
 
         ...
 
     @overload
     def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
         ...
 
-    def cursor(self, cursor_factory  = _Cursor, **kwargs): # type: ignore
+    def cursor(self, cursor_factory  = Cursor, **kwargs): # type: ignore
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
index 461cb662415910844b87042aa76df2754046e13f..9a7b4f164787b8abb03831477fe3b876357e9b25 100644 (file)
@@ -7,14 +7,14 @@
 """
 Helper functions for handling DB accesses.
 """
 """
 Helper functions for handling DB accesses.
 """
-from typing import IO, Optional, Union
+from typing import IO, Optional, Union, Any, Iterable
 import subprocess
 import logging
 import gzip
 import io
 from pathlib import Path
 
 import subprocess
 import logging
 import gzip
 import io
 from pathlib import Path
 
-from nominatim.db.connection import get_pg_env
+from nominatim.db.connection import get_pg_env, Cursor
 from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
@@ -84,20 +84,20 @@ class CopyBuffer:
     """ Data collector for the copy_from command.
     """
 
     """ Data collector for the copy_from command.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.buffer = io.StringIO()
 
 
         self.buffer = io.StringIO()
 
 
-    def __enter__(self):
+    def __enter__(self) -> 'CopyBuffer':
         return self
 
 
         return self
 
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         if self.buffer is not None:
             self.buffer.close()
 
 
         if self.buffer is not None:
             self.buffer.close()
 
 
-    def add(self, *data):
+    def add(self, *data: Any) -> None:
         """ Add another row of data to the copy buffer.
         """
         first = True
         """ Add another row of data to the copy buffer.
         """
         first = True
@@ -113,9 +113,9 @@ class CopyBuffer:
         self.buffer.write('\n')
 
 
         self.buffer.write('\n')
 
 
-    def copy_out(self, cur, table, columns=None):
+    def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
         """ Copy all collected data into the given table.
         """
         if self.buffer.tell() > 0:
             self.buffer.seek(0)
         """ Copy all collected data into the given table.
         """
         if self.buffer.tell() > 0:
             self.buffer.seek(0)
-            cur.copy_from(self.buffer, table, columns=columns)
+            cur.copy_from(self.buffer, table, columns=columns) # type: ignore[no-untyped-call]