git.openstreetmap.org Git - nominatim.git/commitdiff
add type annotations for indexer
author: Sarah Hoffmann <lonvia@denofr.de>
Tue, 12 Jul 2022 16:40:51 +0000 (18:40 +0200)
committer: Sarah Hoffmann <lonvia@denofr.de>
Mon, 18 Jul 2022 07:47:57 +0000 (09:47 +0200)
nominatim/db/async_connection.py
nominatim/db/connection.py
nominatim/indexer/indexer.py
nominatim/indexer/progress.py
nominatim/indexer/runners.py
nominatim/typing.py

index 1c550198ec053be6e70e0ae25cce194d66e70e96..a2c8fe4de8a630e54001b7a08e0da41c79a9684d 100644 (file)
@@ -6,7 +6,7 @@
 # For a full list of authors see the git log.
 """ Non-blocking database connections.
 """
 # For a full list of authors see the git log.
 """ Non-blocking database connections.
 """
-from typing import Callable, Any, Optional, List, Iterator
+from typing import Callable, Any, Optional, Iterator, Sequence
 import logging
 import select
 import time
 import logging
 import select
 import time
@@ -22,7 +22,7 @@ try:
 except ImportError:
     __has_psycopg2_errors__ = False
 
 except ImportError:
     __has_psycopg2_errors__ = False
 
-from nominatim.typing import T_cursor
+from nominatim.typing import T_cursor, Query
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
@@ -65,8 +65,8 @@ class DBConnection:
                  ignore_sql_errors: bool = False) -> None:
         self.dsn = dsn
 
                  ignore_sql_errors: bool = False) -> None:
         self.dsn = dsn
 
-        self.current_query: Optional[str] = None
-        self.current_params: Optional[List[Any]] = None
+        self.current_query: Optional[Query] = None
+        self.current_params: Optional[Sequence[Any]] = None
         self.ignore_sql_errors = ignore_sql_errors
 
         self.conn: Optional['psycopg2.connection'] = None
         self.ignore_sql_errors = ignore_sql_errors
 
         self.conn: Optional['psycopg2.connection'] = None
@@ -128,7 +128,7 @@ class DBConnection:
                 self.current_query = None
                 return
 
                 self.current_query = None
                 return
 
-    def perform(self, sql: str, args: Optional[List[Any]] = None) -> None:
+    def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
         """ Send SQL query to the server. Returns immediately without
             blocking.
         """
         """ Send SQL query to the server. Returns immediately without
             blocking.
         """
index 5a1b46959d457b6ce3f9dd7e4d3e6b00590f9434..25ddcba4c887d46039ba765ff57952f9288852ba 100644 (file)
@@ -74,7 +74,7 @@ class Cursor(psycopg2.extras.DictCursor):
         if cascade:
             sql += ' CASCADE'
 
         if cascade:
             sql += ' CASCADE'
 
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore[no-untyped-call]
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
 
 
 class Connection(psycopg2.extensions.connection):
 
 
 class Connection(psycopg2.extensions.connection):
index 555f8704a19c6796da4b97a724cd363d183f7f12..4f7675309cbaa91068f777a97789b5b2e809c5ac 100644 (file)
@@ -7,15 +7,18 @@
 """
 Main work horse for indexing (computing addresses) the database.
 """
 """
 Main work horse for indexing (computing addresses) the database.
 """
+from typing import Optional, Any, cast
 import logging
 import time
 
 import psycopg2.extras
 
 import logging
 import time
 
 import psycopg2.extras
 
+from nominatim.tokenizer.base import AbstractTokenizer
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection, WorkerPool
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection, WorkerPool
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection, Cursor
+from nominatim.typing import DictCursorResults
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
@@ -23,10 +26,11 @@ LOG = logging.getLogger()
 class PlaceFetcher:
     """ Asynchronous connection that fetches place details for processing.
     """
 class PlaceFetcher:
     """ Asynchronous connection that fetches place details for processing.
     """
-    def __init__(self, dsn, setup_conn):
-        self.wait_time = 0
-        self.current_ids = None
-        self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor)
+    def __init__(self, dsn: str, setup_conn: Connection) -> None:
+        self.wait_time = 0.0
+        self.current_ids: Optional[DictCursorResults] = None
+        self.conn: Optional[DBConnection] = DBConnection(dsn,
+                                               cursor_factory=psycopg2.extras.DictCursor)
 
         with setup_conn.cursor() as cur:
             # need to fetch those manually because register_hstore cannot
 
         with setup_conn.cursor() as cur:
             # need to fetch those manually because register_hstore cannot
@@ -37,7 +41,7 @@ class PlaceFetcher:
         psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                         array_oid=hstore_array_oid)
 
         psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                         array_oid=hstore_array_oid)
 
-    def close(self):
+    def close(self) -> None:
         """ Close the underlying asynchronous connection.
         """
         if self.conn:
         """ Close the underlying asynchronous connection.
         """
         if self.conn:
@@ -45,44 +49,46 @@ class PlaceFetcher:
             self.conn = None
 
 
             self.conn = None
 
 
-    def fetch_next_batch(self, cur, runner):
+    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
         """ Send a request for the next batch of places.
             If details for the places are required, they will be fetched
             asynchronously.
 
             Returns true if there is still data available.
         """
         """ Send a request for the next batch of places.
             If details for the places are required, they will be fetched
             asynchronously.
 
             Returns true if there is still data available.
         """
-        ids = cur.fetchmany(100)
+        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
 
         if not ids:
             self.current_ids = None
             return False
 
 
         if not ids:
             self.current_ids = None
             return False
 
-        if hasattr(runner, 'get_place_details'):
-            runner.get_place_details(self.conn, ids)
-            self.current_ids = []
-        else:
-            self.current_ids = ids
+        assert self.conn is not None
+        self.current_ids = runner.get_place_details(self.conn, ids)
 
         return True
 
 
         return True
 
-    def get_batch(self):
+    def get_batch(self) -> DictCursorResults:
         """ Get the next batch of data, previously requested with
             `fetch_next_batch`.
         """
         """ Get the next batch of data, previously requested with
             `fetch_next_batch`.
         """
+        assert self.conn is not None
+        assert self.conn.cursor is not None
+
         if self.current_ids is not None and not self.current_ids:
             tstart = time.time()
             self.conn.wait()
             self.wait_time += time.time() - tstart
         if self.current_ids is not None and not self.current_ids:
             tstart = time.time()
             self.conn.wait()
             self.wait_time += time.time() - tstart
-            self.current_ids = self.conn.cursor.fetchall()
+            self.current_ids = cast(Optional[DictCursorResults],
+                                    self.conn.cursor.fetchall())
 
 
-        return self.current_ids
+        return self.current_ids if self.current_ids is not None else []
 
 
-    def __enter__(self):
+    def __enter__(self) -> 'PlaceFetcher':
         return self
 
 
         return self
 
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        assert self.conn is not None
         self.conn.wait()
         self.close()
 
         self.conn.wait()
         self.close()
 
@@ -91,13 +97,13 @@ class Indexer:
     """ Main indexing routine.
     """
 
     """ Main indexing routine.
     """
 
-    def __init__(self, dsn, tokenizer, num_threads):
+    def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int):
         self.dsn = dsn
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
 
         self.dsn = dsn
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
 
-    def has_pending(self):
+    def has_pending(self) -> bool:
         """ Check if any data still needs indexing.
             This function must only be used after the import has finished.
             Otherwise it will be very expensive.
         """ Check if any data still needs indexing.
             This function must only be used after the import has finished.
             Otherwise it will be very expensive.
@@ -108,7 +114,7 @@ class Indexer:
                 return cur.rowcount > 0
 
 
                 return cur.rowcount > 0
 
 
-    def index_full(self, analyse=True):
+    def index_full(self, analyse: bool = True) -> None:
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
@@ -117,7 +123,7 @@ class Indexer:
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
-            def _analyze():
+            def _analyze() -> None:
                 if analyse:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
                 if analyse:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
@@ -138,7 +144,7 @@ class Indexer:
             _analyze()
 
 
             _analyze()
 
 
-    def index_boundaries(self, minrank, maxrank):
+    def index_boundaries(self, minrank: int, maxrank: int) -> None:
         """ Index only administrative boundaries within the given rank range.
         """
         LOG.warning("Starting indexing boundaries using %s threads",
         """ Index only administrative boundaries within the given rank range.
         """
         LOG.warning("Starting indexing boundaries using %s threads",
@@ -148,7 +154,7 @@ class Indexer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
                 self._index(runners.BoundaryRunner(rank, analyzer))
 
             for rank in range(max(minrank, 4), min(maxrank, 26)):
                 self._index(runners.BoundaryRunner(rank, analyzer))
 
-    def index_by_rank(self, minrank, maxrank):
+    def index_by_rank(self, minrank: int, maxrank: int) -> None:
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
@@ -168,7 +174,7 @@ class Indexer:
                 self._index(runners.InterpolationRunner(analyzer), 20)
 
 
                 self._index(runners.InterpolationRunner(analyzer), 20)
 
 
-    def index_postcodes(self):
+    def index_postcodes(self) -> None:
         """Index the entries ofthe location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
         """Index the entries ofthe location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
@@ -176,7 +182,7 @@ class Indexer:
         self._index(runners.PostcodeRunner(), 20)
 
 
         self._index(runners.PostcodeRunner(), 20)
 
 
-    def update_status_table(self):
+    def update_status_table(self) -> None:
         """ Update the status in the status table to 'indexed'.
         """
         with connect(self.dsn) as conn:
         """ Update the status in the status table to 'indexed'.
         """
         with connect(self.dsn) as conn:
@@ -185,7 +191,7 @@ class Indexer:
 
             conn.commit()
 
 
             conn.commit()
 
-    def _index(self, runner, batch=1):
+    def _index(self, runner: runners.Runner, batch: int = 1) -> None:
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
index b758e10d52abce50da267df4953803dbf67e36eb..bc1d68a3c1c50a6f61ae93c313f9782f26d9135b 100644 (file)
@@ -22,7 +22,7 @@ class ProgressLogger:
         should be reported.
     """
 
         should be reported.
     """
 
-    def __init__(self, name, total, log_interval=1):
+    def __init__(self, name: str, total: int, log_interval: int = 1) -> None:
         self.name = name
         self.total_places = total
         self.done_places = 0
         self.name = name
         self.total_places = total
         self.done_places = 0
@@ -30,7 +30,7 @@ class ProgressLogger:
         self.log_interval = log_interval
         self.next_info = INITIAL_PROGRESS if LOG.isEnabledFor(logging.WARNING) else total + 1
 
         self.log_interval = log_interval
         self.next_info = INITIAL_PROGRESS if LOG.isEnabledFor(logging.WARNING) else total + 1
 
-    def add(self, num=1):
+    def add(self, num: int = 1) -> None:
         """ Mark `num` places as processed. Print a log message if the
             logging is at least info and the log interval has passed.
         """
         """ Mark `num` places as processed. Print a log message if the
             logging is at least info and the log interval has passed.
         """
@@ -55,14 +55,14 @@ class ProgressLogger:
 
         self.next_info += int(places_per_sec) * self.log_interval
 
 
         self.next_info += int(places_per_sec) * self.log_interval
 
-    def done(self):
+    def done(self) -> None:
         """ Print final statistics about the progress.
         """
         rank_end_time = datetime.now()
 
         if rank_end_time == self.rank_start_time:
         """ Print final statistics about the progress.
         """
         rank_end_time = datetime.now()
 
         if rank_end_time == self.rank_start_time:
-            diff_seconds = 0
-            places_per_sec = self.done_places
+            diff_seconds = 0.0
+            places_per_sec = float(self.done_places)
         else:
             diff_seconds = (rank_end_time - self.rank_start_time).total_seconds()
             places_per_sec = self.done_places / diff_seconds
         else:
             diff_seconds = (rank_end_time - self.rank_start_time).total_seconds()
             places_per_sec = self.done_places / diff_seconds
index c8495ee4df115ef96cf7971cff7c7eb1318948a0..973c6ea98c0cf4741a61ebd72c91770530ca056c 100644 (file)
@@ -8,35 +8,49 @@
 Mix-ins that provide the actual commands for the indexer for various indexing
 tasks.
 """
 Mix-ins that provide the actual commands for the indexer for various indexing
 tasks.
 """
+from typing import Any, List
 import functools
 
 import functools
 
+from typing_extensions import Protocol
 from psycopg2 import sql as pysql
 import psycopg2.extras
 
 from nominatim.data.place_info import PlaceInfo
 from psycopg2 import sql as pysql
 import psycopg2.extras
 
 from nominatim.data.place_info import PlaceInfo
+from nominatim.tokenizer.base import AbstractAnalyzer
+from nominatim.db.async_connection import DBConnection
+from nominatim.typing import Query, DictCursorResult, DictCursorResults
 
 # pylint: disable=C0111
 
 
 # pylint: disable=C0111
 
-def _mk_valuelist(template, num):
+def _mk_valuelist(template: str, num: int) -> pysql.Composed:
     return pysql.SQL(',').join([pysql.SQL(template)] * num)
 
     return pysql.SQL(',').join([pysql.SQL(template)] * num)
 
-def _analyze_place(place, analyzer):
+def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json:
     return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place)))
 
     return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place)))
 
+
+class Runner(Protocol):
+    def name(self) -> str: ...
+    def sql_count_objects(self) -> Query: ...
+    def sql_get_objects(self) -> Query: ...
+    def get_place_details(self, worker: DBConnection,
+                          ids: DictCursorResults) -> DictCursorResults: ...
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ...
+
+
 class AbstractPlacexRunner:
     """ Returns SQL commands for indexing of the placex table.
     """
     SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
     UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
 
 class AbstractPlacexRunner:
     """ Returns SQL commands for indexing of the placex table.
     """
     SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
     UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
 
-    def __init__(self, rank, analyzer):
+    def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None:
         self.rank = rank
         self.analyzer = analyzer
 
 
         self.rank = rank
         self.analyzer = analyzer
 
 
-    @staticmethod
     @functools.lru_cache(maxsize=1)
     @functools.lru_cache(maxsize=1)
-    def _index_sql(num_places):
+    def _index_sql(self, num_places: int) -> pysql.Composed:
         return pysql.SQL(
             """ UPDATE placex
                 SET indexed_status = 0, address = v.addr, token_info = v.ti,
         return pysql.SQL(
             """ UPDATE placex
                 SET indexed_status = 0, address = v.addr, token_info = v.ti,
@@ -46,16 +60,17 @@ class AbstractPlacexRunner:
             """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places))
 
 
             """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places))
 
 
-    @staticmethod
-    def get_place_details(worker, ids):
+    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
         worker.perform("""SELECT place_id, extra.*
                           FROM placex, LATERAL placex_indexing_prepare(placex) as extra
                           WHERE place_id IN %s""",
                        (tuple((p[0] for p in ids)), ))
 
         worker.perform("""SELECT place_id, extra.*
                           FROM placex, LATERAL placex_indexing_prepare(placex) as extra
                           WHERE place_id IN %s""",
                        (tuple((p[0] for p in ids)), ))
 
+        return []
+
 
 
-    def index_places(self, worker, places):
-        values = []
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
+        values: List[Any] = []
         for place in places:
             for field in ('place_id', 'name', 'address', 'linked_place_id'):
                 values.append(place[field])
         for place in places:
             for field in ('place_id', 'name', 'address', 'linked_place_id'):
                 values.append(place[field])
@@ -68,15 +83,15 @@ class RankRunner(AbstractPlacexRunner):
     """ Returns SQL commands for indexing one rank within the placex table.
     """
 
     """ Returns SQL commands for indexing one rank within the placex table.
     """
 
-    def name(self):
+    def name(self) -> str:
         return f"rank {self.rank}"
 
         return f"rank {self.rank}"
 
-    def sql_count_objects(self):
+    def sql_count_objects(self) -> pysql.Composed:
         return pysql.SQL("""SELECT count(*) FROM placex
                             WHERE rank_address = {} and indexed_status > 0
                          """).format(pysql.Literal(self.rank))
 
         return pysql.SQL("""SELECT count(*) FROM placex
                             WHERE rank_address = {} and indexed_status > 0
                          """).format(pysql.Literal(self.rank))
 
-    def sql_get_objects(self):
+    def sql_get_objects(self) -> pysql.Composed:
         return self.SELECT_SQL + pysql.SQL(
             """WHERE indexed_status > 0 and rank_address = {}
                ORDER BY geometry_sector
         return self.SELECT_SQL + pysql.SQL(
             """WHERE indexed_status > 0 and rank_address = {}
                ORDER BY geometry_sector
@@ -88,17 +103,17 @@ class BoundaryRunner(AbstractPlacexRunner):
         of a certain rank.
     """
 
         of a certain rank.
     """
 
-    def name(self):
+    def name(self) -> str:
         return f"boundaries rank {self.rank}"
 
         return f"boundaries rank {self.rank}"
 
-    def sql_count_objects(self):
+    def sql_count_objects(self) -> pysql.Composed:
         return pysql.SQL("""SELECT count(*) FROM placex
                             WHERE indexed_status > 0
                               AND rank_search = {}
                               AND class = 'boundary' and type = 'administrative'
                          """).format(pysql.Literal(self.rank))
 
         return pysql.SQL("""SELECT count(*) FROM placex
                             WHERE indexed_status > 0
                               AND rank_search = {}
                               AND class = 'boundary' and type = 'administrative'
                          """).format(pysql.Literal(self.rank))
 
-    def sql_get_objects(self):
+    def sql_get_objects(self) -> pysql.Composed:
         return self.SELECT_SQL + pysql.SQL(
             """WHERE indexed_status > 0 and rank_search = {}
                      and class = 'boundary' and type = 'administrative'
         return self.SELECT_SQL + pysql.SQL(
             """WHERE indexed_status > 0 and rank_search = {}
                      and class = 'boundary' and type = 'administrative'
@@ -111,37 +126,33 @@ class InterpolationRunner:
         location_property_osmline.
     """
 
         location_property_osmline.
     """
 
-    def __init__(self, analyzer):
+    def __init__(self, analyzer: AbstractAnalyzer) -> None:
         self.analyzer = analyzer
 
 
         self.analyzer = analyzer
 
 
-    @staticmethod
-    def name():
+    def name(self) -> str:
         return "interpolation lines (location_property_osmline)"
 
         return "interpolation lines (location_property_osmline)"
 
-    @staticmethod
-    def sql_count_objects():
+    def sql_count_objects(self) -> str:
         return """SELECT count(*) FROM location_property_osmline
                   WHERE indexed_status > 0"""
 
         return """SELECT count(*) FROM location_property_osmline
                   WHERE indexed_status > 0"""
 
-    @staticmethod
-    def sql_get_objects():
+    def sql_get_objects(self) -> str:
         return """SELECT place_id
                   FROM location_property_osmline
                   WHERE indexed_status > 0
                   ORDER BY geometry_sector"""
 
 
         return """SELECT place_id
                   FROM location_property_osmline
                   WHERE indexed_status > 0
                   ORDER BY geometry_sector"""
 
 
-    @staticmethod
-    def get_place_details(worker, ids):
+    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
         worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
                           FROM location_property_osmline WHERE place_id IN %s""",
                        (tuple((p[0] for p in ids)), ))
         worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
                           FROM location_property_osmline WHERE place_id IN %s""",
                        (tuple((p[0] for p in ids)), ))
+        return []
 
 
 
 
-    @staticmethod
     @functools.lru_cache(maxsize=1)
     @functools.lru_cache(maxsize=1)
-    def _index_sql(num_places):
+    def _index_sql(self, num_places: int) -> pysql.Composed:
         return pysql.SQL("""UPDATE location_property_osmline
                             SET indexed_status = 0, address = v.addr, token_info = v.ti
                             FROM (VALUES {}) as v(id, addr, ti)
         return pysql.SQL("""UPDATE location_property_osmline
                             SET indexed_status = 0, address = v.addr, token_info = v.ti
                             FROM (VALUES {}) as v(id, addr, ti)
@@ -149,8 +160,8 @@ class InterpolationRunner:
                          """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places))
 
 
                          """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places))
 
 
-    def index_places(self, worker, places):
-        values = []
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
+        values: List[Any] = []
         for place in places:
             values.extend((place[x] for x in ('place_id', 'address')))
             values.append(_analyze_place(place, self.analyzer))
         for place in places:
             values.extend((place[x] for x in ('place_id', 'address')))
             values.append(_analyze_place(place, self.analyzer))
@@ -159,26 +170,28 @@ class InterpolationRunner:
 
 
 
 
 
 
-class PostcodeRunner:
+class PostcodeRunner(Runner):
     """ Provides the SQL commands for indexing the location_postcode table.
     """
 
     """ Provides the SQL commands for indexing the location_postcode table.
     """
 
-    @staticmethod
-    def name():
+    def name(self) -> str:
         return "postcodes (location_postcode)"
 
         return "postcodes (location_postcode)"
 
-    @staticmethod
-    def sql_count_objects():
+
+    def sql_count_objects(self) -> str:
         return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
 
         return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
 
-    @staticmethod
-    def sql_get_objects():
+
+    def sql_get_objects(self) -> str:
         return """SELECT place_id FROM location_postcode
                   WHERE indexed_status > 0
                   ORDER BY country_code, postcode"""
 
         return """SELECT place_id FROM location_postcode
                   WHERE indexed_status > 0
                   ORDER BY country_code, postcode"""
 
-    @staticmethod
-    def index_places(worker, ids):
+
+    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
+        return ids
+
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
         worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
                                     WHERE place_id IN ({})""")
         worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
                                     WHERE place_id IN ({})""")
-                       .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in ids))))
+                       .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places))))
index 6d7549899bda1ae674c4ee42a0a717ea10eb0445..36bde8347e7dd4fc8d94e9b9ea422f5ec1a77810 100644 (file)
@@ -9,14 +9,15 @@ Type definitions for typing annotations.
 
 Complex type definitions are moved here, to keep the source files readable.
 """
 
 Complex type definitions are moved here, to keep the source files readable.
 """
-from typing import Union, Mapping, TypeVar, TYPE_CHECKING
+from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING
 
 # Generics variable names do not conform to naming styles, ignore globally here.
 
 # Generics variable names do not conform to naming styles, ignore globally here.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name,abstract-method,multiple-statements,missing-class-docstring
 
 if TYPE_CHECKING:
     import psycopg2.sql
     import psycopg2.extensions
 
 if TYPE_CHECKING:
     import psycopg2.sql
     import psycopg2.extensions
+    import psycopg2.extras
     import os
 
 StrPath = Union[str, 'os.PathLike[str]']
     import os
 
 StrPath = Union[str, 'os.PathLike[str]']
@@ -26,4 +27,12 @@ SysEnv = Mapping[str, str]
 # psycopg2-related types
 
 Query = Union[str, bytes, 'psycopg2.sql.Composable']
 # psycopg2-related types
 
 Query = Union[str, bytes, 'psycopg2.sql.Composable']
+
+T_ResultKey = TypeVar('T_ResultKey', int, str)
+
+class DictCursorResult(Mapping[str, Any]):
+    def __getitem__(self, x: Union[int, str]) -> Any: ...
+
+DictCursorResults = Sequence[DictCursorResult]
+
 T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')
 T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')