]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/clicmd/export.py
use correct SQLAlchemy pool for asynchronous connections
[nominatim.git] / nominatim / clicmd / export.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of the 'export' subcommand.
9 """
10 from typing import Optional, List, cast
11 import logging
12 import argparse
13 import asyncio
14 import csv
15 import sys
16
17 import sqlalchemy as sa
18
19 from nominatim.clicmd.args import NominatimArgs
20 import nominatim.api as napi
21 from nominatim.api.results import create_from_placex_row, ReverseResult, add_result_details
22 from nominatim.api.types import LookupDetails
23 from nominatim.errors import UsageError
24
25 # Do not repeat documentation of subcommand classes.
26 # pylint: disable=C0111
27 # Using non-top-level imports to avoid eventually unused imports.
28 # pylint: disable=E0012,C0415
29 # Needed for SQLAlchemy
30 # pylint: disable=singleton-comparison
31
LOG = logging.getLogger()

# Inclusive (min, max) address-rank interval that is exported for each
# value of --output-type.
RANK_RANGE_MAP = {
    'country': (4, 4),
    'state': (5, 9),
    'county': (10, 12),
    'city': (13, 16),
    'suburb': (17, 21),
    'street': (26, 26),
    'path': (27, 27)
}

# Reverse lookup: address rank -> name of the CSV output column.
# Derived from RANK_RANGE_MAP so that the two tables always agree.
RANK_TO_OUTPUT_MAP = {rank: column
                      for column, (first, last) in RANK_RANGE_MAP.items()
                      for rank in range(first, last + 1)}
51
class QueryExport:
    """\
    Export places as CSV file from the database.


    """

    def add_args(self, parser: argparse.ArgumentParser) -> None:
        # Options controlling which kind of place is exported and which
        # columns/language the CSV output uses.
        group = parser.add_argument_group('Output arguments')
        group.add_argument('--output-type', default='street',
                           choices=('country', 'state', 'county',
                                    'city', 'suburb', 'street', 'path'),
                           help='Type of places to output (default: street)')
        group.add_argument('--output-format',
                           default='street;suburb;city;county;state;country',
                           help=("Semicolon-separated list of address types "
                                 "(see --output-type). Additionally accepts: "
                                 "placeid, postcode"))
        group.add_argument('--language',
                           help=("Preferred language for output "
                                 "(use local name, if omitted)"))
        # Options restricting the set of exported places. The three OSM
        # restrictions are stored as args.node / args.way / args.relation;
        # get_parent_id() only honours the first one that is set, checked in
        # the order node, way, relation.
        group = parser.add_argument_group('Filter arguments')
        group.add_argument('--restrict-to-country', metavar='COUNTRY_CODE',
                           help='Export only objects within country')
        group.add_argument('--restrict-to-osm-node', metavar='ID', type=int,
                           dest='node',
                           help='Export only children of this OSM node')
        group.add_argument('--restrict-to-osm-way', metavar='ID', type=int,
                           dest='way',
                           help='Export only children of this OSM way')
        group.add_argument('--restrict-to-osm-relation', metavar='ID', type=int,
                           dest='relation',
                           help='Export only children of this OSM relation')


    def run(self, args: 'NominatimArgs') -> int:
        # The export itself is asynchronous; drive it to completion on a
        # fresh event loop and hand its exit code back to the CLI.
        return asyncio.run(export(args))
86
87     def run(self, args: NominatimArgs) -> int:
88         return asyncio.run(export(args))
89
90
async def export(args: NominatimArgs) -> int:
    """ The actual export as an asynchronous function.

        Selects all places from placex that have no linked_place_id and
        whose rank_address lies in the range configured for
        `args.output_type`. The selection can optionally be restricted to
        places whose address (isaddress lines) contains a given OSM object
        (`args.node` / `args.way` / `args.relation`) and to a country
        (`args.restrict_to_country`). Results are written as CSV to stdout
        in batches.

        Returns 0 on success. Raises UsageError (from get_parent_id) when
        the requested OSM parent object cannot be found; in that case no
        CSV output at all is produced.
    """
    # Number of results collected before their address details are looked
    # up and the corresponding CSV rows are written.
    batch_size = 1000

    api = napi.NominatimAPIAsync(args.project_dir)

    try:
        output_range = RANK_RANGE_MAP[args.output_type]

        # Address details of each batch are looked up through a separate
        # connection (detail_conn) from the one running the main query.
        async with api.begin() as conn, api.begin() as detail_conn:
            t = conn.t.placex

            sql = sa.select(t.c.place_id, t.c.parent_place_id,
                            t.c.osm_type, t.c.osm_id, t.c.name,
                            t.c.class_, t.c.type, t.c.admin_level,
                            t.c.address, t.c.extratags,
                            t.c.housenumber, t.c.postcode, t.c.country_code,
                            t.c.importance, t.c.wikipedia, t.c.indexed_date,
                            t.c.rank_address, t.c.rank_search,
                            t.c.centroid)\
                    .where(t.c.linked_place_id == None)\
                    .where(t.c.rank_address.between(*output_range))

            # Resolve the OSM parent before any output is produced, so that
            # an unknown OSM object does not leave a stray CSV header behind.
            parent_place_id = await get_parent_id(conn, args.node, args.way, args.relation)
            if parent_place_id is not None:
                taddr = conn.t.addressline

                # Keep only places that list the parent in their address
                # lines with the isaddress flag set.
                sql = sql.join(taddr, taddr.c.place_id == t.c.place_id)\
                         .where(taddr.c.address_place_id == parent_place_id)\
                         .where(taddr.c.isaddress)

            if args.restrict_to_country:
                sql = sql.where(t.c.country_code == args.restrict_to_country.lower())

            writer = init_csv_writer(args.output_format)

            results: List[ReverseResult] = []
            for row in await conn.execute(sql):
                result = create_from_placex_row(row, ReverseResult)
                if result is None:
                    continue

                results.append(result)
                if len(results) >= batch_size:
                    await dump_results(detail_conn, results, writer, args.language)
                    results = []

            # Flush the last, partially filled batch.
            if results:
                await dump_results(detail_conn, results, writer, args.language)
    finally:
        await api.close()

    return 0
140         await api.close()
141
142     return 0
143
144
145 def init_csv_writer(output_format: str) -> 'csv.DictWriter[str]':
146     fields = output_format.split(';')
147     writer = csv.DictWriter(sys.stdout, fieldnames=fields, extrasaction='ignore')
148     writer.writeheader()
149
150     return writer
151
152
153 async def dump_results(conn: napi.SearchConnection,
154                        results: List[ReverseResult],
155                        writer: 'csv.DictWriter[str]',
156                        lang: Optional[str]) -> None:
157     locale = napi.Locales([lang] if lang else None)
158     await add_result_details(conn, results,
159                              LookupDetails(address_details=True, locales=locale))
160
161
162     for result in results:
163         data = {'placeid': result.place_id,
164                 'postcode': result.postcode}
165
166         for line in (result.address_rows or []):
167             if line.isaddress and line.local_name:
168                 if line.category[1] == 'postcode':
169                     data['postcode'] = line.local_name
170                 elif line.rank_address in RANK_TO_OUTPUT_MAP:
171                     data[RANK_TO_OUTPUT_MAP[line.rank_address]] = line.local_name
172
173         writer.writerow(data)
174
175
176 async def get_parent_id(conn: napi.SearchConnection, node_id: Optional[int],
177                         way_id: Optional[int],
178                         relation_id: Optional[int]) -> Optional[int]:
179     """ Get the place ID for the given OSM object.
180     """
181     if node_id is not None:
182         osm_type, osm_id = 'N', node_id
183     elif way_id is not None:
184         osm_type, osm_id = 'W', way_id
185     elif relation_id is not None:
186         osm_type, osm_id = 'R', relation_id
187     else:
188         return None
189
190     t = conn.t.placex
191     sql = sa.select(t.c.place_id).limit(1)\
192             .where(t.c.osm_type == osm_type)\
193             .where(t.c.osm_id == osm_id)\
194             .where(t.c.rank_address > 0)\
195             .order_by(t.c.rank_address)
196
197     for result in await conn.execute(sql):
198         return cast(int, result[0])
199
200     raise UsageError(f'Cannot find a place {osm_type}{osm_id}.')