if args.analyse_indexing:
LOG.warning('Analysing performance of indexing function')
from ..tools import admin
- conn = connect(args.config.get_libpq_dsn())
- admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
return 0
def run(args):
from ..tools import freeze
- conn = connect(args.config.get_libpq_dsn())
- freeze.drop_update_tables(conn)
+ with connect(args.config.get_libpq_dsn()) as conn:
+ freeze.drop_update_tables(conn)
freeze.drop_flatnode_file(args.config.FLATNODE_FILE)
- conn.close()
return 0
if not args.no_boundaries and not args.boundaries_only \
and args.minrank == 0 and args.maxrank == 30:
- conn = connect(args.config.get_libpq_dsn())
- status.set_indexed(conn, True)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ status.set_indexed(conn, True)
return 0
if args.postcodes:
LOG.warning("Update postcodes centroid")
- conn = connect(args.config.get_libpq_dsn())
- refresh.update_postcodes(conn, args.sqllib_dir)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ refresh.update_postcodes(conn, args.sqllib_dir)
if args.word_counts:
LOG.warning('Recompute frequency of full-word search terms')
- conn = connect(args.config.get_libpq_dsn())
- refresh.recompute_word_counts(conn, args.sqllib_dir)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ refresh.recompute_word_counts(conn, args.sqllib_dir)
if args.address_levels:
cfg = Path(args.config.ADDRESS_LEVEL_CONFIG)
LOG.warning('Updating address levels from %s', cfg)
- conn = connect(args.config.get_libpq_dsn())
- refresh.load_address_levels_from_file(conn, cfg)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ refresh.load_address_levels_from_file(conn, cfg)
if args.functions:
LOG.warning('Create functions')
- conn = connect(args.config.get_libpq_dsn())
- refresh.create_functions(conn, args.config, args.sqllib_dir,
- args.diffs, args.enable_debug_statements)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ refresh.create_functions(conn, args.config, args.sqllib_dir,
+ args.diffs, args.enable_debug_statements)
if args.wiki_data:
run_legacy_script('setup.php', '--import-wikipedia-articles',
from ..tools import replication, refresh
LOG.warning("Initialising replication updates")
- conn = connect(args.config.get_libpq_dsn())
- replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
- if args.update_functions:
- LOG.warning("Create functions")
- refresh.create_functions(conn, args.config, args.sqllib_dir,
- True, False)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
+ if args.update_functions:
+ LOG.warning("Create functions")
+ refresh.create_functions(conn, args.config, args.sqllib_dir,
+ True, False)
return 0
def _check_for_updates(args):
from ..tools import replication
- conn = connect(args.config.get_libpq_dsn())
- ret = replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
- conn.close()
- return ret
+ with connect(args.config.get_libpq_dsn()) as conn:
+ return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
@staticmethod
def _report_update(batchdate, start_import, start_index):
recheck_interval = args.config.get_int('REPLICATION_RECHECK_INTERVAL')
while True:
- conn = connect(args.config.get_libpq_dsn())
- start = dt.datetime.now(dt.timezone.utc)
- state = replication.update(conn, params)
- if state is not replication.UpdateState.NO_CHANGES:
- status.log_status(conn, start, 'import')
- batchdate, _, _ = status.get_status(conn)
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ start = dt.datetime.now(dt.timezone.utc)
+ state = replication.update(conn, params)
+ if state is not replication.UpdateState.NO_CHANGES:
+ status.log_status(conn, start, 'import')
+ batchdate, _, _ = status.get_status(conn)
if state is not replication.UpdateState.NO_CHANGES and args.do_index:
index_start = dt.datetime.now(dt.timezone.utc)
indexer.index_boundaries(0, 30)
indexer.index_by_rank(0, 30)
- conn = connect(args.config.get_libpq_dsn())
- status.set_indexed(conn, True)
- status.log_status(conn, index_start, 'index')
- conn.close()
+ with connect(args.config.get_libpq_dsn()) as conn:
+ status.set_indexed(conn, True)
+ status.log_status(conn, index_start, 'index')
else:
index_start = None
"""
Specialised connection and cursor functions.
"""
+import contextlib
import logging
import psycopg2
def connect(dsn):
""" Open a connection to the database using the specialised connection
- factory.
+ factory. The returned object may be used as a context manager with
+ 'with'. When used outside of a 'with' block, use the `connection`
+ attribute to get access to the underlying connection.
"""
try:
- return psycopg2.connect(dsn, connection_factory=_Connection)
+ conn = psycopg2.connect(dsn, connection_factory=_Connection)
+ ctxmgr = contextlib.closing(conn)
+ ctxmgr.connection = conn
+ return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError("Cannot connect to database: {}".format(err)) from err
""" Run a number of checks on the database and return the status.
"""
try:
- conn = connect(config.get_libpq_dsn())
+ conn = connect(config.get_libpq_dsn()).connection
except UsageError as err:
conn = _BadConnection(str(err))
def temp_db_conn(temp_db):
""" Connection to the test database.
"""
- conn = connection.connect('dbname=' + temp_db)
- yield conn
- conn.close()
+ with connection.connect('dbname=' + temp_db) as conn:
+ yield conn
@pytest.fixture
@pytest.fixture
def db(temp_db):
- conn = connect('dbname=' + temp_db)
- yield conn
- conn.close()
+ with connect('dbname=' + temp_db) as conn:
+ yield conn
def test_connection_table_exists(db, temp_db_cursor):
@pytest.fixture
def db(temp_db, placex_table):
- conn = connect('dbname=' + temp_db)
- yield conn
- conn.close()
+ with connect('dbname=' + temp_db) as conn:
+ yield conn
def test_analyse_indexing_no_objects(db):
with pytest.raises(UsageError):
assert 1 == chkdb.check_database(def_config)
+def test_check_database_fatal_test(def_config, temp_db):
+ assert 1 == chkdb.check_database(def_config)
+
+
def test_check_conection_good(temp_db_conn, def_config):
assert chkdb.check_connection(temp_db_conn, def_config) == chkdb.CheckState.OK
@pytest.fixture
def db(temp_db):
- conn = connect('dbname=' + temp_db)
- yield conn
- conn.close()
+ with connect('dbname=' + temp_db) as conn:
+ yield conn
@pytest.fixture
def db_with_tables(db):