"""
Main work horse for indexing (computing addresses) the database.
"""
-from typing import Optional, Any, cast
+from typing import cast, List, Any
import logging
import time
-import psycopg2.extras
+import psycopg
-from ..typing import DictCursorResults
-from ..db.async_connection import DBConnection, WorkerPool
-from ..db.connection import connect, Connection, Cursor
+from ..db.connection import connect, execute_scalar
+from ..db.query_pool import QueryPool
from ..tokenizer.base import AbstractTokenizer
from .progress import ProgressLogger
from . import runners
LOG = logging.getLogger()
-
-class PlaceFetcher:
- """ Asynchronous connection that fetches place details for processing.
- """
- def __init__(self, dsn: str, setup_conn: Connection) -> None:
- self.wait_time = 0.0
- self.current_ids: Optional[DictCursorResults] = None
- self.conn: Optional[DBConnection] = DBConnection(dsn,
- cursor_factory=psycopg2.extras.DictCursor)
-
- with setup_conn.cursor() as cur:
- # need to fetch those manually because register_hstore cannot
- # fetch them on an asynchronous connection below.
- hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
- hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
-
- psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
- array_oid=hstore_array_oid)
-
- def close(self) -> None:
- """ Close the underlying asynchronous connection.
- """
- if self.conn:
- self.conn.close()
- self.conn = None
-
-
- def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
- """ Send a request for the next batch of places.
- If details for the places are required, they will be fetched
- asynchronously.
-
- Returns true if there is still data available.
- """
- ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
-
- if not ids:
- self.current_ids = None
- return False
-
- assert self.conn is not None
- self.current_ids = runner.get_place_details(self.conn, ids)
-
- return True
-
- def get_batch(self) -> DictCursorResults:
- """ Get the next batch of data, previously requested with
- `fetch_next_batch`.
- """
- assert self.conn is not None
- assert self.conn.cursor is not None
-
- if self.current_ids is not None and not self.current_ids:
- tstart = time.time()
- self.conn.wait()
- self.wait_time += time.time() - tstart
- self.current_ids = cast(Optional[DictCursorResults],
- self.conn.cursor.fetchall())
-
- return self.current_ids if self.current_ids is not None else []
-
- def __enter__(self) -> 'PlaceFetcher':
- return self
-
-
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- assert self.conn is not None
- self.conn.wait()
- self.close()
-
-
class Indexer:
""" Main indexing routine.
"""
return cur.rowcount > 0
- def index_full(self, analyse: bool = True) -> None:
+ async def index_full(self, analyse: bool = True) -> None:
""" Index the complete database. This will first index boundaries
followed by all other objects. When `analyse` is True, then the
database will be analysed at the appropriate places to
with conn.cursor() as cur:
cur.execute('ANALYZE')
- if self.index_by_rank(0, 4) > 0:
- _analyze()
+ while True:
+ if await self.index_by_rank(0, 4) > 0:
+ _analyze()
+
+ if await self.index_boundaries(0, 30) > 100:
+ _analyze()
- if self.index_boundaries(0, 30) > 100:
- _analyze()
+ if await self.index_by_rank(5, 25) > 100:
+ _analyze()
- if self.index_by_rank(5, 25) > 100:
- _analyze()
+ if await self.index_by_rank(26, 30) > 1000:
+ _analyze()
- if self.index_by_rank(26, 30) > 1000:
- _analyze()
+ if await self.index_postcodes() > 100:
+ _analyze()
- if self.index_postcodes() > 100:
- _analyze()
+ if not self.has_pending():
+ break
- def index_boundaries(self, minrank: int, maxrank: int) -> int:
+ async def index_boundaries(self, minrank: int, maxrank: int) -> int:
""" Index only administrative boundaries within the given rank range.
"""
total = 0
with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(minrank, 4), min(maxrank, 26)):
- total += self._index(runners.BoundaryRunner(rank, analyzer))
+ total += await self._index(runners.BoundaryRunner(rank, analyzer))
return total
- def index_by_rank(self, minrank: int, maxrank: int) -> int:
+ async def index_by_rank(self, minrank: int, maxrank: int) -> int:
""" Index all entries of placex in the given rank range (inclusive)
in order of their address rank.
with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(1, minrank), maxrank + 1):
- total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
+ if rank >= 30:
+ batch = 20
+ elif rank >= 26:
+ batch = 5
+ else:
+ batch = 1
+ total += await self._index(runners.RankRunner(rank, analyzer), batch)
if maxrank == 30:
- total += self._index(runners.RankRunner(0, analyzer))
- total += self._index(runners.InterpolationRunner(analyzer), 20)
+ total += await self._index(runners.RankRunner(0, analyzer))
+ total += await self._index(runners.InterpolationRunner(analyzer), 20)
return total
- def index_postcodes(self) -> int:
+ async def index_postcodes(self) -> int:
"""Index the entries of the location_postcode table.
"""
LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
- return self._index(runners.PostcodeRunner(), 20)
+ # Postcodes are handed to the query pool 20 at a time, i.e. one SQL
+ # statement per 20 entries; the result is whatever _index() reports.
+ return await self._index(runners.PostcodeRunner(), 20)
def update_status_table(self) -> None:
conn.commit()
- def _index(self, runner: runners.Runner, batch: int = 1) -> int:
+ async def _index(self, runner: runners.Runner, batch: int = 1) -> int:
""" Index a single rank or table. `runner` describes the SQL to use
for indexing. `batch` describes the number of objects that
should be processed with a single SQL statement
"""
LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
- with connect(self.dsn) as conn:
- psycopg2.extras.register_hstore(conn)
- with conn.cursor() as cur:
- total_tuples = cur.scalar(runner.sql_count_objects())
- LOG.debug("Total number of rows: %i", total_tuples)
-
- conn.commit()
+ total_tuples = self._prepare_indexing(runner)
- progress = ProgressLogger(runner.name(), total_tuples)
+ progress = ProgressLogger(runner.name(), total_tuples)
- if total_tuples > 0:
- with conn.cursor(name='places') as cur:
- cur.execute(runner.sql_get_objects())
+ if total_tuples > 0:
+ async with await psycopg.AsyncConnection.connect(
+ self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\
+ QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
+ # fetcher_time accumulates the time spent waiting for the next
+ # row from the named cursor: tstart is reset after each row has
+ # been processed, so pool time is not counted here.
+ fetcher_time = 0.0
+ tstart = time.time()
+ async with aconn.cursor(name='places') as cur:
+ query = runner.index_places_query(batch)
+ params: List[Any] = []
+ num_places = 0
+ async for place in cur.stream(runner.sql_get_objects()):
+ fetcher_time += time.time() - tstart
- with PlaceFetcher(self.dsn, conn) as fetcher:
- with WorkerPool(self.dsn, self.num_threads) as pool:
- has_more = fetcher.fetch_next_batch(cur, runner)
- while has_more:
- places = fetcher.get_batch()
+ params.extend(runner.index_places_params(place))
+ num_places += 1
- # asynchronously get the next batch
- has_more = fetcher.fetch_next_batch(cur, runner)
+ # A full batch goes to the pool as a single SQL statement.
+ if num_places >= batch:
+ LOG.debug("Processing places: %s", str(params))
+ await pool.put_query(query, params)
+ progress.add(num_places)
+ params = []
+ num_places = 0
- # And insert the current batch
- for idx in range(0, len(places), batch):
- part = places[idx:idx + batch]
- LOG.debug("Processing places: %s", str(part))
- runner.index_places(pool.next_free_worker(), part)
- progress.add(len(part))
+ tstart = time.time()
- LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
- fetcher.wait_time, pool.wait_time)
+ # Flush the trailing partial batch (fewer than `batch` places)
+ # with a query sized to exactly that many places.
+ if num_places > 0:
+ LOG.debug("Processing places: %s", str(params))
+ await pool.put_query(runner.index_places_query(num_places), params)
+ # Count it like every full batch above; otherwise the figure
+ # returned by progress.done() misses up to batch-1 places.
+ progress.add(num_places)
- conn.commit()
+ LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
+ fetcher_time, pool.wait_time)
+ # The return value is the final count reported by the progress logger.
return progress.done()
+
+
+ def _prepare_indexing(self, runner: runners.Runner) -> int:
+ """ Register the hstore adapter and return the number of objects
+ `runner` is going to index (result of its count query).
+
+ Raises RuntimeError when the hstore extension is not installed
+ in the database.
+ """
+ # Import the adapter explicitly: a plain `import psycopg` is not
+ # guaranteed to load the psycopg.types.hstore submodule, so reaching
+ # it through the package attribute could fail with AttributeError.
+ from psycopg.types.hstore import register_hstore # pylint: disable=import-outside-toplevel
+
+ with connect(self.dsn) as conn:
+ hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
+ if hstore_info is None:
+ raise RuntimeError('Hstore extension is requested but not installed.')
+ # No context argument: the adapter is registered globally, so the
+ # async connections opened afterwards in _index() use it as well.
+ register_hstore(hstore_info)
+
+ total_tuples = execute_scalar(conn, runner.sql_count_objects())
+ LOG.debug("Total number of rows: %i", total_tuples)
+ return cast(int, total_tuples)