]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/convert_sqlite.py
Merge pull request #3383 from lonvia/window-searches
[nominatim.git] / nominatim / tools / convert_sqlite.py
index d9e39ba37402b7a9dcc7455fa1665ce978b82eb9..1e7beae57645c172e15f72f999ce178a0ea4beeb 100644 (file)
@@ -7,13 +7,14 @@
 """
 Exporting a Nominatim database to SQlite.
 """
-from typing import Set
+from typing import Set, Any
+import datetime as dt
 import logging
 from pathlib import Path
 
 import sqlalchemy as sa
 
-from nominatim.typing import SaSelect
+from nominatim.typing import SaSelect, SaRow
 from nominatim.db.sqlalchemy_types import Geometry, IntArray
 from nominatim.api.search.query_analyzer_factory import make_query_analyzer
 import nominatim.api as napi
@@ -28,7 +29,8 @@ async def convert(project_dir: Path, outfile: Path, options: Set[str]) -> None:
 
     try:
         outapi = napi.NominatimAPIAsync(project_dir,
-                                        {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={outfile}"})
+                                        {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={outfile}",
+                                         'NOMINATIM_DATABASE_RW': '1'})
 
         try:
             async with api.begin() as src, outapi.begin() as dest:
@@ -123,12 +125,20 @@ class SqliteWriter:
     async def copy_data(self) -> None:
         """ Copy data for all registered tables.
         """
+        def _getfield(row: SaRow, key: str) -> Any:
+            value = getattr(row, key)
+            if isinstance(value, dt.datetime):
+                if value.tzinfo is not None:
+                    value = value.astimezone(dt.timezone.utc)
+            return value
+
         for table in self.dest.t.meta.sorted_tables:
             LOG.warning("Copying '%s'", table.name)
             async_result = await self.src.connection.stream(self.select_from(table.name))
 
             async for partition in async_result.partitions(10000):
-                data = [{('class_' if k == 'class' else k): getattr(r, k) for k in r._fields}
+                data = [{('class_' if k == 'class' else k): _getfield(r, k)
+                         for k in r._fields}
                         for r in partition]
                 await self.dest.execute(table.insert(), data)
 
@@ -205,15 +215,15 @@ class SqliteWriter:
     async def create_search_index(self) -> None:
         """ Create the tables and indexes needed for word lookup.
         """
+        LOG.warning("Creating reverse search table")
+        rsn = sa.Table('reverse_search_name', self.dest.t.meta,
+                       sa.Column('word', sa.Integer()),
+                       sa.Column('column', sa.Text()),
+                       sa.Column('places', IntArray))
+        await self.dest.connection.run_sync(rsn.create)
+
         tsrc = self.src.t.search_name
         for column in ('name_vector', 'nameaddress_vector'):
-            table_name = f'reverse_search_{column}'
-            LOG.warning("Creating reverse search %s", table_name)
-            rsn = sa.Table(table_name, self.dest.t.meta,
-                           sa.Column('word', sa.Integer()),
-                           sa.Column('places', IntArray))
-            await self.dest.connection.run_sync(rsn.create)
-
             sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'),
                             sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\
                     .group_by('word')
@@ -224,11 +234,12 @@ class SqliteWriter:
                 for row in partition:
                     row.places.sort()
                     data.append({'word': row.word,
+                                 'column': column,
                                  'places': row.places})
                 await self.dest.execute(rsn.insert(), data)
 
-            await self.dest.connection.run_sync(
-                sa.Index(f'idx_reverse_search_{column}_word', rsn.c.word).create)
+        await self.dest.connection.run_sync(
+            sa.Index('idx_reverse_search_name_word', rsn.c.word).create)
 
 
     def select_from(self, table: str) -> SaSelect: