Files
openstreetmap-osm2pgsql/tests/bdd/environment.py
2025-01-13 09:07:30 +01:00

163 lines
5.7 KiB
Python

# SPDX-License-Identifier: GPL-2.0-or-later
#
# This file is part of osm2pgsql (https://osm2pgsql.org/).
#
# Copyright (C) 2006-2025 by the osm2pgsql developer community.
# For a full list of authors see the git log.
from contextlib import closing
from pathlib import Path
import subprocess
import tempfile
import importlib.util
import io
from importlib.machinery import SourceFileLoader
from behave import *
try:
import psycopg2 as psycopg
from psycopg2 import sql
except ImportError:
import psycopg
from psycopg import sql
from steps.geometry_factory import GeometryFactory
from steps.replication_server_mock import ReplicationServerMock
# Directory two levels above this file. All default paths below are
# derived from it.
TEST_BASE_DIR = Path(__file__, '..', '..').resolve()

# The following parameters can be changed on the command line using
# the -D parameter. Example:
#
#    behave -DBINARY=/tmp/my-builddir/osm2pgsql -DKEEP_TEST_DB
USER_CONFIG = {
    # TEST_BASE_DIR is already resolved, so `.parent` names the same
    # directory as resolving `TEST_BASE_DIR / '..'`.
    'BINARY': (TEST_BASE_DIR.parent / 'build' / 'osm2pgsql').resolve(),
    'REPLICATION_SCRIPT': (TEST_BASE_DIR.parent / 'scripts' / 'osm2pgsql-replication').resolve(),
    'TEST_DATA_DIR': TEST_BASE_DIR / 'data',
    'SRC_DIR': TEST_BASE_DIR.parent,
    'KEEP_TEST_DB': False,
    'TEST_DB': 'osm2pgsql-test',
    'HAVE_TABLESPACE': True,
    'HAVE_PROJ': True
}

# Select the 're' step matcher of behave for the step definitions.
use_step_matcher('re')
def _connect_db(context, dbname):
    """ Open a connection to the database `dbname` and return it as a
        context manager that closes the connection when the `with`
        block is left.

        Autocommit is switched on for the connection, both with
        psycopg2 and with psycopg 3. The `context` parameter is
        not used.
    """
    uses_psycopg2 = psycopg.__version__.startswith('2')

    if not uses_psycopg2:
        # psycopg 3: autocommit is requested directly on connect and
        # the connection object itself is returned.
        return psycopg.connect(dbname=dbname, autocommit=True)

    # psycopg2: set autocommit after connecting and wrap the connection
    # in closing() so that it is closed at the end of the `with` block.
    connection = psycopg.connect(dbname=dbname)
    connection.autocommit = True
    return closing(connection)
def _drop_db(context, dbname, recreate_immediately=False):
    """ Remove the database `dbname` if it exists. The statements are
        sent over a connection to the 'postgres' database.

        When `recreate_immediately` is true, an empty database with
        the same name is created right after the drop.
    """
    with _connect_db(context, 'postgres') as conn, conn.cursor() as cur:
        # Quote the name as an SQL identifier instead of pasting it
        # into the statement text.
        db_ident = sql.Identifier(dbname)

        statements = [sql.SQL('DROP DATABASE IF EXISTS {}')]
        if recreate_immediately:
            statements.append(sql.SQL('CREATE DATABASE {}'))

        for template in statements:
            cur.execute(template.format(db_ident))
def before_all(context):
    """ Prepare the global test configuration before any feature runs.

        The hook
          * sets up logging and fills in defaults from USER_CONFIG for
            all user parameters not given with -D,
          * checks that the tablespace 'tablespacetest' exists (only
            when HAVE_TABLESPACE is enabled) and stores the result back
            into HAVE_TABLESPACE,
          * always stores the server version number as PG_VERSION,
          * runs `osm2pgsql --version` and disables HAVE_PROJ when the
            output reports 'Proj [disabled]',
          * resolves the data directories and loads the replication
            script as module `context.osm2pgsql_replication`.

        Raises RuntimeError when the osm2pgsql binary exits with a
        non-zero return code.
    """
    # logging setup
    context.config.setup_logging()

    userdata = context.config.userdata

    # set up -D options
    for key, value in USER_CONFIG.items():
        userdata.setdefault(key, value)

    with _connect_db(context, 'postgres') as conn:
        with conn.cursor() as cur:
            if userdata['HAVE_TABLESPACE']:
                cur.execute("""SELECT spcname FROM pg_tablespace
                               WHERE spcname = 'tablespacetest'""")
                userdata['HAVE_TABLESPACE'] = cur.rowcount > 0

            # The server version is needed independently of the
            # tablespace setting: before_tag() reads PG_VERSION
            # unconditionally.
            cur.execute("""SELECT setting FROM pg_settings
                           WHERE name = 'server_version_num'""")
            userdata['PG_VERSION'] = int(cur.fetchone()[0])

    # Get the osm2pgsql configuration. The version information is read
    # from stderr.
    result = subprocess.run([str(userdata['BINARY']), '--version'],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    if result.returncode != 0:
        raise RuntimeError('Cannot run osm2pgsql')

    ver_info = result.stderr.decode('utf-8')

    if userdata['HAVE_PROJ']:
        userdata['HAVE_PROJ'] = 'Proj [disabled]' not in ver_info

    context.test_data_dir = Path(userdata['TEST_DATA_DIR']).resolve()
    context.default_data_dir = Path(userdata['SRC_DIR']).resolve()

    # Set up replication script. An explicit SourceFileLoader is used
    # to load it from the configured path under the module name
    # 'osm2pgsql_replication'.
    replicationfile = str(Path(userdata['REPLICATION_SCRIPT']).resolve())
    spec = importlib.util.spec_from_loader('osm2pgsql_replication',
                                           SourceFileLoader('osm2pgsql_replication',
                                                            replicationfile))
    assert spec, f"File not found: {replicationfile}"

    context.osm2pgsql_replication = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(context.osm2pgsql_replication)
def before_scenario(context, scenario):
    """ Set up a fresh, empty test database and reset the per-scenario
        state kept on `context`.

        Scenarios tagged 'config.have_proj' are marked as skipped when
        HAVE_PROJ is disabled. The replication module gets a new
        ReplicationServerMock instance, and its `urlrequest.urlopen`
        is replaced by a function that answers from
        `context.urlrequest_responses`.
    """
    # Imported here because the module-level imports of this file do
    # not bring `urllib` into scope; without it the mock below fails
    # with a NameError instead of raising URLError.
    import urllib.error

    if 'config.have_proj' in scenario.tags and not context.config.userdata['HAVE_PROJ']:
        scenario.skip("Generic proj library not configured.")

    context.db = use_fixture(test_db, context)
    context.import_file = None
    context.import_data = {'n': [], 'w': [], 'r': []}
    context.osm2pgsql_params = []
    context.workdir = use_fixture(working_directory, context)
    context.geometry_factory = GeometryFactory()
    context.osm2pgsql_replication.ReplicationServer = ReplicationServerMock()
    context.urlrequest_responses = {}

    def _mock_urlopen(request):
        """ Return the response registered for the URL of `request` as
            a closable byte stream. Raise URLError for unknown URLs.
        """
        url = request.full_url
        if url not in context.urlrequest_responses:
            raise urllib.error.URLError('Unknown URL')

        return closing(io.BytesIO(context.urlrequest_responses[url].encode('utf-8')))

    context.osm2pgsql_replication.urlrequest.urlopen = _mock_urlopen
@fixture
def test_db(context, **kwargs):
    """ Fixture that provides an open connection to a newly created
        test database.

        The database named by the TEST_DB parameter is dropped and
        recreated, then the extensions postgis and hstore are created
        in it. On teardown the connection is closed and the database
        is dropped again unless KEEP_TEST_DB is set.
    """
    db_name = context.config.userdata['TEST_DB']
    _drop_db(context, db_name, recreate_immediately=True)

    extension_statements = ('CREATE EXTENSION postgis',
                            'CREATE EXTENSION hstore')

    with _connect_db(context, db_name) as connection:
        with connection.cursor() as cursor:
            for statement in extension_statements:
                cursor.execute(statement)

        # The connection stays open while the scenario is running.
        yield connection

    if context.config.userdata['KEEP_TEST_DB']:
        return

    _drop_db(context, db_name)
@fixture
def working_directory(context, **kwargs):
    """ Fixture that yields a new temporary directory as a Path. The
        directory is removed again when the fixture is torn down.
    """
    scratch = tempfile.TemporaryDirectory()
    with scratch as dirname:
        yield Path(dirname)
def before_tag(context, tag):
    """ Skip the current scenario when it carries the tag
        'needs-pg-index-includes' and the PG_VERSION parameter is
        below 110000. All other tags are ignored.
    """
    # The version is only looked up for the tag that needs it.
    needs_includes = tag == 'needs-pg-index-includes'
    if needs_includes and context.config.userdata['PG_VERSION'] < 110000:
        context.scenario.skip("No index includes in PostgreSQL < 11")