mirror of
https://github.com/osm2pgsql-dev/osm2pgsql.git
synced 2026-01-14 03:17:03 +00:00
835 lines
30 KiB
Python
Executable File
835 lines
30 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
# SPDX-License-Identifier: GPL-2.0-or-later
|
|
#
|
|
# This file is part of osm2pgsql (https://osm2pgsql.org/).
|
|
#
|
|
# Copyright (C) 2025 by the osm2pgsql developer community.
|
|
# For a full list of authors see the git log.
|
|
"""
|
|
Test runner for BDD-style integration tests.
|
|
|
|
See osm2pgsql manual for more information on osm2pgsql style testing.
|
|
"""
|
|
import logging
|
|
import sys
|
|
import tempfile
|
|
import math
|
|
import re
|
|
import os
|
|
import contextlib
|
|
import json
|
|
import datetime as dt
|
|
from decimal import Decimal
|
|
from subprocess import Popen, PIPE
|
|
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
|
from pathlib import Path
|
|
import importlib.util
|
|
import io
|
|
from importlib.machinery import SourceFileLoader
|
|
|
|
from behave import given, when, then, use_step_matcher, use_fixture, fixture
|
|
from behave.runner import ModelRunner, Context
|
|
from behave.formatter.base import StreamOpener
|
|
from behave.formatter.pretty import PrettyFormatter
|
|
from behave import runner_util
|
|
from behave.configuration import Configuration
|
|
|
|
# Module-wide logger. Calling getLogger without a name returns the root logger.
LOG = logging.getLogger(None)
|
|
|
|
import psycopg
|
|
from psycopg import sql
|
|
|
|
# All step patterns in this file are regular expressions.
use_step_matcher('re')

# Sort rank of the OSM object types in OPL input: nodes, ways, relations.
OBJECT_ORDER = dict(n=1, w=2, r=3)
|
|
|
|
def opl_sort(line):
    """Return a sort key that orders OPL lines by object type, then by ID.

    The first space-separated token of an OPL line is the object identifier
    (e.g. ``n12``). Its leading letter is ranked through OBJECT_ORDER and the
    remaining digits are compared numerically.
    """
    object_id, *_ = line.split(' ', 1)
    type_rank = OBJECT_ORDER[object_id[0]]
    return type_rank, int(object_id[1:])
|
|
|
|
#################### Replication mock ##############################
|
|
|
|
class ReplicationServerMock:
    """Stand-in for the replication server used by osm2pgsql-replication.

    The steps install an instance in place of the ReplicationServer class of
    the osm2pgsql_replication module. When the code under test "instantiates"
    it with a base URL, __call__ checks that URL and returns the mock itself.

    state_infos is a list of objects with ``sequence`` and ``timestamp``
    attributes (OsmosisState in the steps). It is assumed to be ordered by
    ascending sequence/timestamp.
    """

    def __init__(self, base_url, state_infos):
        self.expected_base_url = base_url
        self.state_infos = state_infos

    def __call__(self, base_url):
        assert base_url == self.expected_base_url,\
            f"Wrong replication service called. Expected '{self.expected_base_url}', got '{base_url}'"
        return self

    def get_state_info(self, seq=None, retries=2):
        """Return the state with sequence *seq*, or the newest state if *seq* is None.

        Returns None when no state has the requested sequence. *retries* is
        accepted for interface compatibility and ignored.
        """
        assert self.state_infos, 'Replication mock not properly set up'
        if seq is None:
            return self.state_infos[-1]

        return next((info for info in self.state_infos if info.sequence == seq), None)

    def timestamp_to_sequence(self, timestamp, balanced_search=False):
        """Return the sequence of the newest state not newer than *timestamp*.

        Timestamps before the first state map to the first state.
        *balanced_search* is ignored.
        """
        assert self.state_infos, 'Replication mock not properly set up'

        if timestamp < self.state_infos[0].timestamp:
            return self.state_infos[0].sequence

        prev = self.state_infos[0]
        for info in self.state_infos:
            if prev.timestamp <= timestamp < info.timestamp:
                return prev.sequence
            prev = info

        return prev.sequence

    def apply_diffs(self, handler, start_id, max_size=1024, idx="", simplify=True):
        """Pretend to apply diffs starting with sequence *start_id*.

        Every diff counts as 1024 units of *max_size*, so ceil(max_size / 1024)
        diffs are "applied". Returns the last applied sequence (capped at the
        newest state), or None when *start_id* is beyond the newest state.
        *handler*, *idx* and *simplify* are ignored.
        """
        # Same precondition check as the other methods instead of a bare
        # IndexError on an empty state list.
        assert self.state_infos, 'Replication mock not properly set up'

        if start_id > self.state_infos[-1].sequence:
            return None

        numdiffs = (max_size + 1023) // 1024
        return min(self.state_infos[-1].sequence, start_id + numdiffs - 1)
|
|
|
|
|
|
# Replication module is optional: it is only available when the
# osm2pgsql-replication script is found next to this file and it, together
# with pyosmium's OsmosisState, can be imported. Otherwise
# osm2pgsql_replication stays None and the replication steps skip the scenario.
_repfl_path = Path(__file__, '..', 'osm2pgsql-replication').resolve()

osm2pgsql_replication = None

if _repfl_path.is_file():
    _repfl_spec = importlib.util.spec_from_loader(
        'osm2pgsql_replication',
        SourceFileLoader('osm2pgsql_replication', str(_repfl_path)))
    try:
        _repfl_module = importlib.util.module_from_spec(_repfl_spec)
        _repfl_spec.loader.exec_module(_repfl_module)

        from osmium.replication.server import OsmosisState
    except ImportError as err:
        LOG.warning("Replication tests disabled: cannot load %s (%s)",
                    _repfl_path, err)
    else:
        osm2pgsql_replication = _repfl_module


#################### hooks #########################################

def hook_before_all(context):
    """Global setup before any feature runs.

    Resolves the 'auto'/'yes'/'no' values of --test-tablespace and
    --test-proj into booleans. It also checks that the osm2pgsql binary can be
    executed and sets up the template database fixture.

    Raises:
        RuntimeError: if ``osm2pgsql --version`` exits with an error.
    """
    context.config.setup_logging(logging.INFO)

    # Feature check: table spaces
    if context.user_args.test_tablespace == 'auto':
        with context.connect_db('postgres') as conn:
            with conn.cursor() as cur:
                cur.execute("""SELECT spcname FROM pg_tablespace
                               WHERE spcname = 'tablespacetest'""")
                context.user_args.test_tablespace = cur.rowcount > 0
                LOG.info('Check if tablespaces are available: %s',
                         'yes' if context.user_args.test_tablespace else 'no')
    else:
        context.user_args.test_tablespace = context.user_args.test_tablespace == 'yes'

    # Test that osm2pgsql works. The version information is read from stderr.
    proc = Popen([context.user_args.osm2pgsql_binary, '--version'],
                 stdout=PIPE, stderr=PIPE)
    _, serr = proc.communicate()
    osm2pgsql_version = serr.decode('utf-8', errors='replace')
    if proc.returncode != 0:
        LOG.critical("Could not run osm2pgsql. Error:\n%s", osm2pgsql_version)
        LOG.critical("osm2pgsql binary used: %s", context.user_args.osm2pgsql_binary)
        raise RuntimeError('Error running osm2pgsql')

    # Feature check: proj. This must be resolved before it is logged: the raw
    # option value is one of the non-empty strings 'auto', 'yes' or 'no'
    # and is therefore always truthy.
    if context.user_args.test_proj == 'auto':
        context.user_args.test_proj = 'Proj [disabled]' not in osm2pgsql_version
    else:
        context.user_args.test_proj = context.user_args.test_proj == 'yes'

    LOG.info('Check if proj is available: %s',
             'yes' if context.user_args.test_proj else 'no')

    use_fixture(template_test_db, context)


def hook_before_scenario(context, scenario):
    """Per-scenario setup.

    Creates a fresh test database and working directory and resets all
    scenario state. When replication support is available, it also resets the
    replication server mock and routes URL requests of osm2pgsql-replication
    to the responses registered by the steps.
    """
    if 'config.have_proj' in scenario.tags and not context.user_args.test_proj:
        scenario.skip("Generic proj library not configured.")

    context.db = use_fixture(test_db, context)
    context.import_file = None
    context.import_data = None
    context.osm2pgsql_params = {'-d': context.user_args.test_db}
    context.osm2pgsql_returncode = None
    context.workdir = use_fixture(working_directory, context)
    context.nodes = NodeStore()
    context.sql_statements = {}
    context.urlrequest_responses = {}

    if osm2pgsql_replication is None:
        # Replication script not available; there is nothing to mock.
        return

    osm2pgsql_replication.ReplicationServer = None

    def _mock_urlopen(request):
        # Serve only the URLs registered with "the URL ... returns"; any other
        # URL behaves like an unreachable server.
        import urllib.error

        if request.full_url not in context.urlrequest_responses:
            raise urllib.error.URLError('Unknown URL')

        return contextlib.closing(io.BytesIO(context.urlrequest_responses[request.full_url].encode('utf-8')))

    osm2pgsql_replication.urlrequest.urlopen = _mock_urlopen
|
|
|
|
|
|
|
|
#################### fixtures ######################################
|
|
|
|
@fixture
def template_test_db(context, **kwargs):
    """Set up the template database from which every test database is cloned.

    The database is recreated empty, and the postgis and hstore extensions
    are installed. The template name is yielded; the database is dropped again
    on cleanup.
    """
    template_name = context.user_args.template_test_db
    context.drop_db(template_name, recreate_template='default')

    with context.connect_db(template_name) as conn:
        for statement in ('CREATE EXTENSION postgis', 'CREATE EXTENSION hstore'):
            conn.execute(statement)

    yield template_name

    context.drop_db(template_name)
|
|
|
|
|
|
@fixture
def test_db(context, **kwargs):
    """Provide a fresh test database cloned from the template database.

    Yields an open connection to the test database. Afterwards the database
    is dropped unless --keep-test-db was given.
    """
    context.drop_db(recreate_template=context.user_args.template_test_db)

    connection = context.connect_db()
    with connection:
        yield connection

    if context.user_args.keep_test_db:
        return
    context.drop_db()
|
|
|
|
|
|
@fixture
def working_directory(context, **kwargs):
    """Yield a temporary directory as a Path; it is removed on cleanup."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()
|
|
|
|
################### Node location creation #########################
|
|
|
|
class NodeStore:
    """Node coordinates of a scenario, usually defined through a grid table.

    Every grid cell whose content is a number defines the node with that ID.
    The node's coordinates are derived from the cell's column/row position,
    the grid step and the origin.
    """

    def __init__(self):
        # Per-instance mapping node ID -> (x, y). This used to be a
        # class-level dict, which would be shared between instances.
        self.grid = {}

    def set_grid(self, lines, grid_step, origin_x, origin_y):
        """Replace the grid with the one described by *lines*.

        *lines* is a sequence of rows, each a sequence of cell strings.
        Coordinates are rounded to the number of decimals used by the step
        and the origin, to remove floating-point noise such as
        20.000000000000004.
        """
        def _decimals(value):
            # Digits after the decimal point in the shortest repr of value.
            return max(0, -Decimal(str(value)).as_tuple().exponent)

        # Any point origin + k * step has at most this many decimals. The
        # previous `while step < 0` loop never ran, so everything was rounded
        # to one digit and e.g. a 0.001 grid collapsed onto the origin.
        ndigits = max(_decimals(grid_step), _decimals(origin_x), _decimals(origin_y))

        self.grid = {}
        origin_y -= grid_step * (len(lines) - 1)

        for y, line in enumerate(lines):
            for x, pt in enumerate(line):
                if pt.isdigit():
                    self.grid[int(pt)] = (round(origin_x + x * grid_step, ndigits),
                                          round(origin_y + y * grid_step, ndigits))

    def get_as_opl(self):
        """Return all grid nodes as OPL lines of the form 'n<id> x<x> y<y>'."""
        return [f"n{i} x{x} y{y}" for i, (x, y) in self.grid.items()]

    def add_coordinates(self, lines):
        """Yield *lines*, appending grid coordinates to node lines without any.

        Raises AssertionError (lazily, during iteration) for a node without
        coordinates that is not in the grid.
        """
        for line in lines:
            if line.startswith('n') and ' x' not in line:
                nid = int(line.split(' ', 1)[0][1:])
                assert nid in self.grid, \
                    f"OPL error. Node {nid} has no coordinates and is not in grid."
                x, y = self.grid[nid]
                yield f"{line} x{x} y{y}"
            else:
                yield line

    def parse_point(self, pt):
        """Parse 'x y' into a list of floats, or look up a grid node ID."""
        pt = pt.strip()
        if ' ' in pt:
            return list(map(float, pt.split(' ', 1)))

        return self.grid[int(pt)]
|
|
|
|
|
|
################### ResultComparison ###############################
|
|
|
|
class ResultCompare:
    """Compare one result column against the expected values of a step table.

    The table heading is the column expression, optionally followed by a
    format suffix after the last '!':

    * ``!:<spec>``  -- compare against ``"{:<spec>}".format(value)``
    * ``!~<tol>``   -- numeric comparison with absolute tolerance
    * ``!~<tol>%``  -- numeric comparison with relative tolerance in percent
    * ``!<name>``   -- use method ``_compare_<name>`` (i, re, substr, json, geo)

    Without a suffix, ``str(value)`` must equal the expected text. The expected
    text ``NULL`` matches only a None value.
    """

    def __init__(self, heading, nodes):
        self.nodes = nodes
        if '!' in heading:
            self.name, self.fmt = heading.rsplit('!', 1)
            if self.fmt.startswith(':'):
                self.compare = self._intcompare_fmt
            elif self.fmt.startswith('~'):
                if self.fmt.endswith('%'):
                    rel_tol = float(self.fmt[1:-1]) / 100
                    self.compare = lambda exp, val: math.isclose(float(exp), val, rel_tol=rel_tol)
                else:
                    abs_tol = float(self.fmt[1:])
                    self.compare = lambda exp, val: math.isclose(float(exp), val, abs_tol=abs_tol)
            else:
                self.compare = getattr(self, f"_compare_{self.fmt}", None)
                assert self.compare is not None, f"Unknown formatter {self.fmt}"
        else:
            self.name = heading
            self.fmt = None
            self.compare = lambda exp, val: str(val) == exp

    def as_select(self):
        """Return the SQL expression to select; geometries are selected as WKT."""
        if self.fmt == 'geo':
            return f"ST_AsText({self.name})"

        return self.name

    def equals(self, expected, value):
        """Return True if the DB *value* matches the *expected* table text."""
        if expected == 'NULL':
            return value is None

        return self.compare(expected, value)

    def _intcompare_fmt(self, expected, value):
        return expected == f"{{{self.fmt}}}".format(value)

    def _compare_i(self, expected, value):
        return expected.lower() == str(value).lower()

    def _compare_re(self, expected, value):
        return re.fullmatch(expected, str(value)) is not None

    def _compare_substr(self, expected, value):
        return expected in str(value)

    def _compare_json(self, expected, value):
        return json.loads(expected) == value

    def _compare_geo(self, expected, value):
        m = re.fullmatch(r'([A-Z]+)\((.*)\)', value)

        return self._eq_geom(expected, m[1], m[2]) if m else False

    def _eq_geom(self, bdd_geom, pg_type, pg_coords):
        """Compare a BDD geometry description with WKT type and coordinates.

        BDD notation: a point is a node ID or 'x y'; a line string is a
        comma-separated point list; a polygon is '(ring),(ring)...'.
        '[a; b]' denotes a MULTI* geometry and '{a; b}' a geometry collection.
        """
        # MULTI* geometries
        if bdd_geom.startswith('[') and bdd_geom.endswith(']'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = pg_coords[1:-1].split('),(')
            return pg_type.startswith('MULTI') \
                   and len(bdd_parts) == len(pg_parts) \
                   and all(self._eq_geom(b.strip(), pg_type[5:], g.strip())
                           for b, g in zip(bdd_parts, pg_parts))

        # GEOMETRYCOLLECTIONS
        if bdd_geom.startswith('{') and bdd_geom.endswith('}'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = list(map(lambda s: re.fullmatch(r'([A-Z]+)\(([^A-Z]*)\)', s),
                                re.findall('[A-Z]+[^A-Z]+[^,A-Z]', pg_coords)))
            return pg_type.startswith('GEOMETRYCOLLECTION')\
                   and len(bdd_parts) == len(pg_parts)\
                   and all(g is not None and self._eq_geom(b.strip(), g[1], g[2])
                           for b, g in zip(bdd_parts, pg_parts))

        # POINT
        if ',' not in bdd_geom:
            return pg_type == 'POINT' and self._eq_point(bdd_geom, pg_coords)

        # LINESTRING
        if '(' not in bdd_geom:
            bdd_pts = [g.strip() for g in bdd_geom.split(',')]
            pg_pts = [g.strip() for g in pg_coords.split(',')]
            # The point counts must match; zip() alone would silently ignore
            # extra or missing points at the end.
            return pg_type == 'LINESTRING' \
                   and len(bdd_pts) == len(pg_pts) \
                   and all(self._eq_point(b, p) for b, p in zip(bdd_pts, pg_pts))

        # POLYGON
        if pg_type != 'POLYGON':
            return False
        # Polygon comparison is tricky because the polygons don't necessarily
        # end at the same point or have the same winding order.
        # Brute force all possible variants of the expected polygon
        bdd_parts = re.findall(r'\([^)]+\)', bdd_geom)
        pg_parts = [g.strip() for g in pg_coords[1:-1].split('),(')]
        return len(bdd_parts) == len(pg_parts) \
               and all(self._eq_ring(*parts) for parts in zip(bdd_parts, pg_parts))

    def _eq_point(self, bdd_pt, pg_pt):
        """Compare a BDD point (node ID or 'x y') with a WKT 'x y' coordinate."""
        exp_geom = self.nodes.parse_point(bdd_pt)
        pg_geom = list(map(float, pg_pt.split(' ')))
        return len(exp_geom) == len(pg_geom) \
               and all(math.isclose(e, p, rel_tol=0.000001) for e, p in zip(exp_geom, pg_geom))

    def _eq_ring(self, bdd_ring, pg_ring):
        """Compare a '(p1, p2, ..., p1)' ring with a WKT ring.

        Any start point and either winding order is accepted.

        Raises:
            RuntimeError: if the expected ring is not closed.
        """
        bdd_pts = [g.strip() for g in bdd_ring[1:-1].split(',')]
        pg_pts = [g.strip() for g in pg_ring.split(',')]
        if bdd_pts[0] != bdd_pts[-1]:
            raise RuntimeError(f"Invalid polygon {bdd_ring}. "
                               "First and last point need to be the same")
        if len(bdd_pts) != len(pg_pts):
            return False

        for line in (bdd_pts[:-1], bdd_pts[-1:0:-1]):
            for i in range(len(line)):
                if all(self._eq_point(p1, p2) for p1, p2 in
                       zip(line[i:] + line[:i], pg_pts)):
                    return True

        return False
|
|
|
|
################### Steps: Database setup ##########################
|
|
|
|
@given("the database schema (?P<schema>.+)")
def create_db_schema(context, schema):
    """Create the schema named in the step in the test database."""
    statement = "CREATE SCHEMA " + schema
    with context.db.cursor() as cursor:
        cursor.execute(statement)
|
|
|
|
|
|
@when("deleting table (?P<table>.+)")
def delete_table(context, table):
    """Drop the table named in the step from the test database."""
    statement = "DROP TABLE " + table
    with context.db.cursor() as cursor:
        cursor.execute(statement)
|
|
|
|
|
|
################### Steps: OSM data ################################
|
|
|
|
@given("the input file '(?P<osm_file>.+)'")
def osm_set_import_file(context, osm_file):
    """Use an OSM file as input for the next osm2pgsql run.

    A relative path is resolved against --test-data-dir when given,
    otherwise against the directory containing the feature file.
    """
    assert context.import_data is None, \
        "Import file cannot be used together with inline data."

    requested = Path(osm_file)
    if not requested.is_absolute():
        base = context.user_args.test_data_dir or Path(context.feature.filename).parent
        requested = (base / osm_file).resolve()

    context.import_file = requested
|
|
|
|
|
|
@given("the OSM data")
def osm_define_data(context):
    """Add inline OPL data; node lines without coordinates get them from the grid."""
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."

    has_content = bool(context.text.strip())
    lines = context.nodes.add_coordinates(context.text.split('\n')) if has_content else []
    context.append_osm_data(lines)
|
|
|
|
|
|
@given("the OSM data format string")
def osm_define_data(context):
    """Add inline OPL data written as the body of a Python f-string.

    NOTE(review): the step text is passed to eval() and can execute arbitrary
    Python code. This is only acceptable as long as feature files are
    trusted input.
    """
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."
    # No further locals may be defined before eval(): the f-string is
    # evaluated in this function's local scope.
    lines = eval('f"""' + context.text + '"""').split('\n')
    context.append_osm_data(context.nodes.add_coordinates(lines))
|
|
|
|
|
|
@given("the (?P<step>[0-9.]+ )?grid(?: with origin (?P<origin_x>[0-9.-]+) (?P<origin_y>[0-9.-]+))?")
def osm_define_node_grid(context, step, origin_x, origin_y):
    """Define node positions through a grid table and add the nodes to the input.

    The grid step defaults to 0.1 and the origin to (20.0, 20.0).
    """
    grid_step = 0.1 if not step else float(step.strip())
    x = 20.0 if not origin_x else float(origin_x)
    y = 20.0 if not origin_y else float(origin_y)

    assert -180.0 < x < 180.0
    assert -90.0 < y < 90.0

    grid_lines = [context.table.headings]
    grid_lines.extend(list(row) for row in context.table)
    context.nodes.set_grid(grid_lines, grid_step, x, y)
    context.append_osm_data(context.nodes.get_as_opl())
|
|
|
|
################### Steps: Style file ##############################
|
|
|
|
@given("the style file '(?P<style>.+)'")
def setup_style_file(context, style):
    """Select the style file for the next osm2pgsql run.

    A relative name is looked up in --style-data-dir, then in
    --test-data-dir, and finally in the directory of the feature file.
    """
    sfile = Path(style)

    if not sfile.is_absolute():
        search_dirs = (context.user_args.style_data_dir, context.user_args.test_data_dir)
        for directory in search_dirs:
            if directory is not None and (directory / sfile).is_file():
                sfile = directory / sfile
                break
        else:
            sfile = Path(context.feature.filename).parent / sfile

    assert sfile.is_file()
    context.osm2pgsql_params['-S'] = str(sfile.resolve())
|
|
|
|
|
|
@given("the lua style")
def setup_style_inline(context):
    """Write the step text to a Lua style file in the work dir and select it."""
    style_path = context.workdir / 'inline_style.lua'
    style_path.write_text(context.text)
    context.osm2pgsql_params['-S'] = str(style_path)
|
|
|
|
|
|
################### Steps: Running osm2pgsql #######################
|
|
|
|
@when(r"running osm2pgsql(?P<output> \w+)?(?: with parameters)?")
def execute_osm2pgsql(context, output):
    """Run osm2pgsql on the configured input and record the outcome.

    Extra parameters come from the step table, with the placeholders
    {TEST_DATA_DIR}, {STYLE_DATA_DIR} and {TEST_DB} expanded. Parameters
    collected by earlier steps are appended unless the same option already
    appears on the command line. Inline OPL data is piped to stdin;
    otherwise the import file is passed as an argument. Stdout/stderr,
    the command line and the return code are stored on the context.
    """
    assert output in (' flex', ' pgsql', ' null', None)

    params = context.osm2pgsql_params
    if output is not None:
        params['-O'] = output.strip()
        if output == ' pgsql' and '-S' not in params:
            params['-S'] = '{STYLE_DATA_DIR}/default.style'

    replacements = {
        '{TEST_DATA_DIR}': str((context.user_args.test_data_dir or Path('.')).resolve()),
        '{STYLE_DATA_DIR}': str((context.user_args.style_data_dir or Path('.')).resolve()),
        '{TEST_DB}': context.user_args.test_db,
    }

    def _expand(text):
        for placeholder, value in replacements.items():
            text = text.replace(placeholder, value)
        return text

    cmdline = [context.user_args.osm2pgsql_binary]

    if context.table:
        assert not any('<' in h for h in context.table.headings), \
            "Substition in the first line of a table are not supported."
        for cells in [context.table.headings, *context.table]:
            cmdline.extend(_expand(cell) for cell in cells if cell)

    for option, value in params.items():
        if option not in cmdline:
            cmdline += [option, _expand(value)]

    if not context.user_args.test_tablespace \
       and any(arg.startswith('--tablespace') for arg in cmdline):
        context.scenario.skip('tablespace tablespacetest not available')
        return

    if context.import_data is None:
        assert context.import_file is not None, "No input data given."
        cmdline.append(str(context.import_file))
        context.import_file = None
        data_stdin = None
    else:
        opl_lines = sorted(context.import_data.values(), key=opl_sort)
        data_stdin = '\n'.join(opl_lines).encode('utf-8')
        context.import_data = None
        cmdline += ['-r', 'opl', '-']

    proc = Popen(cmdline, cwd=str(context.workdir), stdin=PIPE, stdout=PIPE, stderr=PIPE)
    streams = proc.communicate(input=data_stdin)

    context.osm2pgsql_cmdline = ' '.join(cmdline)
    context.osm2pgsql_outdata = [raw.decode('utf-8').replace('\\n', '\n') for raw in streams]
    context.osm2pgsql_returncode = proc.returncode
    context.osm2pgsql_params = {'-d': context.user_args.test_db}
|
|
|
|
@then("execution is successful")
def osm2pgsql_check_success(context):
    """Assert that the last run exited with code 0; on failure, show the command and its output."""
    exit_code = context.osm2pgsql_returncode
    assert exit_code == 0, \
        f"osm2pgsql failed with error code {exit_code}.\n" \
        f"Command line: {context.osm2pgsql_cmdline}\n" \
        f"Output:\n{context.osm2pgsql_outdata[0]}\n{context.osm2pgsql_outdata[1]}\n"
|
|
|
|
@then(r"execution fails(?: with return code (?P<expected>\d+))?")
def osm2pgsql_check_failure(context, expected):
    """Assert that the last run failed, optionally with a specific return code."""
    exit_code = context.osm2pgsql_returncode
    assert exit_code != 0, "osm2pgsql unexpectedly succeeded"

    if not expected:
        return

    assert exit_code == int(expected), \
        f"osm2pgsql failed with return code {exit_code} instead of {expected}\n" \
        f"Output:\n{context.osm2pgsql_outdata[0]}\n{context.osm2pgsql_outdata[1]}\n"
|
|
|
|
@then(r"the (?P<kind>\w+) output contains")
def check_program_output(context, kind):
    """Assert that every line of the step text occurs in the program output.

    *kind* selects the stream: 'standard' (stdout) or 'error' (stderr).
    Each line is stripped before the substring check.
    """
    stream_index = {'standard': 0, 'error': 1}
    # Previously `assert not "..."`, which used the message as the condition
    # and therefore failed without any explanation.
    assert kind in stream_index, \
        f"Unknown output kind '{kind}'. Expect one of error, standard"
    s = context.osm2pgsql_outdata[stream_index[kind]]

    for line in context.text.split('\n'):
        line = line.strip()
        assert line in s,\
            f"Output '{line}' not found in {kind} output:\n{s}\n"
|
|
|
|
################### Steps: Running Replication #####################
|
|
|
|
@given("the replication service at (?P<base_url>.*)")
def setup_replication_mock(context, base_url):
    """Install a mock replication server at *base_url*.

    Each row of the optional step table holds a sequence number and a UTC
    timestamp in the format YYYY-MM-DDTHH:MM:SSZ. The scenario is skipped
    when the replication script is not available.
    """
    if osm2pgsql_replication is None:
        context.scenario.skip("Replication binary not available. Skip.")
        return

    def _to_state(row):
        sequence = int(row[0])
        timestamp = dt.datetime.strptime(row[1], '%Y-%m-%dT%H:%M:%SZ')
        return OsmosisState(sequence, timestamp.replace(tzinfo=dt.timezone.utc))

    state_infos = [_to_state(row) for row in context.table] if context.table else []
    osm2pgsql_replication.ReplicationServer = ReplicationServerMock(base_url, state_infos)
|
|
|
|
|
|
@given("the URL (?P<base_url>.*) returns")
def mock_url_response(context, base_url):
    """Register the step text as the response body for requests to *base_url*."""
    responses = context.urlrequest_responses
    responses[base_url] = context.text
|
|
|
|
|
|
@when("running osm2pgsql-replication")
def execute_osm2pgsql_replication(context):
    """Run osm2pgsql-replication in-process against the mocked server.

    Arguments come from the step table, with the placeholders
    {TEST_DATA_DIR}, {STYLE_DATA_DIR} and {TEST_DB} expanded. The test
    database is added unless -d/--database is given. For 'update', the
    osm2pgsql binary under test and, unless '--' is present, the default
    style are appended. Stdout and the script's log output are stored in
    context.osm2pgsql_outdata; the return value of main() is stored in
    context.osm2pgsql_returncode.
    """
    assert osm2pgsql_replication is not None
    assert osm2pgsql_replication.ReplicationServer is not None

    cmdline = []

    test_dir = (context.user_args.test_data_dir or Path('.')).resolve()
    style_dir = (context.user_args.style_data_dir or Path('.')).resolve()

    def _template(param):
        return param.replace('{TEST_DATA_DIR}', str(test_dir))\
                    .replace('{STYLE_DATA_DIR}', str(style_dir))\
                    .replace('{TEST_DB}', context.user_args.test_db)

    if context.table:
        assert not any('<' in h for h in context.table.headings), \
            "Substition in the first line of a table are not supported."
        cmdline.extend(_template(h) for h in context.table.headings if h)
        for row in context.table:
            cmdline.extend(_template(c) for c in row if c)

    if '-d' not in cmdline and '--database' not in cmdline:
        cmdline.extend(('-d', context.user_args.test_db))

    if cmdline[0] == 'update':
        cmdline.extend(('--osm2pgsql-cmd', context.user_args.osm2pgsql_binary))

        if '--' not in cmdline:
            cmdline.extend(('--', '-S', str(style_dir / 'default.style')))

    serr = io.StringIO()
    sout = io.StringIO()
    log_handler = logging.StreamHandler(serr)
    osm2pgsql_replication.LOG.addHandler(log_handler)
    try:
        with contextlib.redirect_stdout(sout):
            context.osm2pgsql_returncode = osm2pgsql_replication.main(cmdline)
    finally:
        # Always detach the handler. If main() raised, it would otherwise stay
        # attached and later scenarios would log into this stale buffer.
        osm2pgsql_replication.LOG.removeHandler(log_handler)

    context.osm2pgsql_outdata = [sout.getvalue(), serr.getvalue()]
|
|
|
|
|
|
################### Steps: Inspect database ########################
|
|
|
|
@given("the SQL statement (?P<sql>.+)")
def db_define_sql_statement(context, sql):
    """Store the step text as a named SQL statement for later 'statement' steps."""
    statements = context.sql_statements
    statements[sql] = context.text
|
|
|
|
|
|
@then("there are (?P<exists>no )?tables (?P<tables>.+)")
def db_table_existance(context, exists, tables):
    """Check that each table in the comma-separated list exists (or, with 'no', does not)."""
    expect_missing = exists == 'no '
    for name in (t.strip() for t in tables.split(',')):
        if expect_missing:
            assert not context.table_exists(name), f"Table '{name}' unexpectedly found"
        else:
            assert context.table_exists(name), f"Table '{name}' not found"
|
|
|
|
|
|
@then("table (?P<table>.+) contains(?P<exact> exactly)?")
def db_check_table_content(context, table, exact):
    """Compare the content of *table* with the step table.

    The table name may be schema-qualified. With 'exactly', additional rows
    in the database are an error too.
    """
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."

    identifier = sql.Identifier(*table.split('.', 1))
    context.check_select(identifier, exact is not None)
|
|
|
|
|
|
@then("table (?P<table>.+) doesn't contain")
def db_check_table_content_negative(context, table):
    """Assert that no row of *table* matches any row of the step table."""
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."

    identifier = sql.Identifier(*table.split('.', 1))
    context.check_select_not_contained(identifier)
|
|
|
|
|
|
@then(r"table (?P<table>.+) has (?P<row_num>\d+) rows?")
def db_table_row_count(context, table, row_num):
    """Assert that *table* (optionally 'schema.table') has exactly *row_num* rows."""
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."

    # Split into at most schema and table name, like table_exists() and the
    # other table steps do (previously split('.', 2)).
    query = sql.SQL("SELECT count(*) FROM {}").format(sql.Identifier(*table.split('.', 1)))
    count = context.db.execute(query).fetchone()[0]
    assert count == int(row_num),\
        f"Table {table}: expected {row_num} rows, got {count}"
|
|
|
|
|
|
@then("statement (?P<stmt>.+) returns(?P<exact> exactly)?")
def db_check_table_content(context, stmt, exact):
    """Compare the result of a named SQL statement with the step table."""
    context.execute_steps("Then execution is successful")
    assert stmt in context.sql_statements

    subquery = context.sql_statements[stmt]
    from_clause = sql.SQL(f"({subquery}) _statement_sql")
    context.check_select(from_clause, exact is not None)
|
|
|
|
################### Context ########################################
|
|
|
|
class Osm2pgsqlContext(Context):
    """Behave context with database helpers for the osm2pgsql test steps.

    The parsed command line arguments are available as ``user_args``.
    """

    def __init__(self, runner, args):
        super().__init__(runner)
        self.user_args = args

    def connect_db(self, name=None):
        """Open an autocommit connection to *name* (default: the test database)."""
        dbname = name or self.user_args.test_db
        return psycopg.connect(dbname=dbname, autocommit=True)

    def drop_db(self, name=None, recreate_template=None):
        """Drop database *name* (default: the test database) if it exists.

        With recreate_template='default', the database is then recreated
        empty. With any other non-empty value, it is recreated from the
        database of that name as template.
        """
        db = sql.Identifier(name or self.user_args.test_db)
        with self.connect_db('postgres') as conn:
            conn.execute(sql.SQL('DROP DATABASE IF EXISTS {}').format(db))
            if recreate_template == 'default':
                conn.execute(sql.SQL('CREATE DATABASE {}').format(db))
            elif recreate_template:
                conn.execute(sql.SQL('CREATE DATABASE {} WITH TEMPLATE {}')
                               .format(db, sql.Identifier(recreate_template)))

    def append_osm_data(self, lines, include_untagged=None):
        """Merge OPL *lines* into the pending inline import data.

        Lines are keyed by object ID (first token), so a later line for the
        same object replaces an earlier one. Blank lines are ignored.
        *include_untagged* is currently unused.
        """
        if self.import_data is None:
            self.import_data = {}
        for line in lines:
            if (l := line.strip()):
                self.import_data[l.split(' ', 1)[0]] = l

    def table_exists(self, table):
        """Return True if *table* exists as a table or as a view.

        *table* may be schema-qualified ('schema.name'); otherwise the
        'public' schema is searched.
        """
        sql_params = table.split('.', 1) if '.' in table else ('public', table)

        # Both catalogs must be searched, and they name their name column
        # differently. Previously only pg_tables was ever queried.
        queries = ("SELECT count(*) FROM pg_tables"
                   " WHERE schemaname = %s AND tablename = %s",
                   "SELECT count(*) FROM pg_views"
                   " WHERE schemaname = %s AND viewname = %s")

        for query in queries:
            for res in self.db.execute(query, sql_params):
                if res[0] == 1:
                    return True
        return False

    def check_select(self, from_clause, exact):
        """Assert that every row of the step table is found in *from_clause*.

        The step table headings describe the selected columns (see
        ResultCompare), and each expected row must match a distinct result
        row. With *exact*, result rows that match no expected row are an
        error as well.
        """
        columns = [ResultCompare(h, self.nodes) for h in self.table.headings]
        lines = set(range(0, len(self.table.rows)))

        query = sql.SQL('SELECT {} FROM {}').format(
            sql.SQL(', '.join(f"({r.as_select()}) as c{i}" for i, r in enumerate(columns))),
            from_clause)

        table_content = ''
        unexpected_rows = []
        for row in self.db.execute(query):
            table_content += f"\n{row}"
            for i in lines:
                for attr, expected, value in zip(columns, self.table[i], row):
                    if not attr.equals(expected, value):
                        break
                else:
                    # Removing from the set being iterated is safe because
                    # the loop is left immediately afterwards.
                    lines.remove(i)
                    break
            else:
                if exact:
                    unexpected_rows.append(str(row))

        assert not lines, \
               "Rows not found:\n" \
               + '\n'.join(str(self.table[i]) for i in lines) \
               + "\nTable content:\n" \
               + table_content

        assert not unexpected_rows, \
               "Unexpected rows found:\n" + '\n'.join(unexpected_rows)\
               + "\nTable content:\n" \
               + table_content

    def check_select_not_contained(self, from_clause):
        """Assert that no result row of *from_clause* matches a step table row."""
        columns = [ResultCompare(h, self.nodes) for h in self.table.headings]
        lines = set(range(0, len(self.table.rows)))

        query = sql.SQL('SELECT {} FROM {}').format(
            sql.SQL(', '.join(f"({r.as_select()}) as c{i}" for i, r in enumerate(columns))),
            from_clause)

        table_content = ''
        matching_rows = []
        for row in self.db.execute(query):
            table_content += f"\n{row}"
            for i in lines:
                for attr, expected, value in zip(columns, self.table[i], row):
                    if not attr.equals(expected, value):
                        break
                else:
                    matching_rows.append(str(row))
                    break

        assert not matching_rows, \
               "Matching rows found:\n" + '\n'.join(matching_rows)\
               + "\nFull table content:\n" \
               + table_content
|
|
|
|
|
|
#################### runner and main ###############################
|
|
|
|
class Osm2pgsqlRunner(ModelRunner):
    """Behave model runner using the hooks and the context defined in this script.

    *args* are the parsed command line arguments; ``args.features`` names
    the feature files or directories to run.
    """

    def __init__(self, config, args):
        super().__init__(config)
        self.feature_locations = runner_util.collect_feature_locations(args.features)
        self.hooks = dict(before_all=hook_before_all,
                          before_scenario=hook_before_scenario)
        self.context = Osm2pgsqlContext(self, args)

    def run(self):
        """Parse the features, run them with a pretty formatter on stdout and
        return the result of run_model()."""
        self.features.extend(runner_util.parse_features(self.feature_locations))

        self.formatters = [PrettyFormatter(StreamOpener(stream=sys.stdout), self.config)]
        return self.run_model()
|
|
|
|
|
|
def get_parser():
    """Build the command line parser of osm2pgsql-test-style."""
    parser = ArgumentParser(prog='osm2pgsql-test-style',
                            description=__doc__,
                            formatter_class=RawDescriptionHelpFormatter)
    add = parser.add_argument

    add('features', nargs='+',
        help='Feature files or paths')
    add('--osm2pgsql-binary',
        help='osm2pgsql binary to use for testing (default: osm2pgsql)')
    add('--test-data-dir', type=Path,
        help='(optional) directory to search for test data')
    add('--style-data-dir', type=Path,
        help='(optional) directory to search for style files')
    add('--test-db', default='osm2pgsql-test',
        help='Name of database to use for testing (default: osm2pgsql-test)')
    add('--template-test-db', default='osm2pgsql-test-template',
        help='Name of database to use for creating the template db '
             '(default: osm2pgsql-test-template)')
    add('--keep-test-db', action='store_true',
        help='Keep the test database around after tests are done')

    # Tri-state feature switches; 'auto' is resolved when the tests start.
    add('--test-tablespace', default='auto', choices=['yes', 'no', 'auto'],
        help='Include tests requiring a tablespace')
    add('--test-proj', default='auto', choices=['yes', 'no', 'auto'],
        help='Include tests requiring the proj library')

    return parser
|
|
|
|
def main(prog_args=None):
    """Run the style tests and return the process exit code.

    Returns 0 when all tests pass (or after printing --help) and 1 on test
    failures or invalid command line arguments.
    """
    parser = get_parser()
    try:
        args = parser.parse_args(args=prog_args)
    except SystemExit as exc:
        # argparse exits with 0 after printing --help, which must not be
        # reported as a failure; usage errors still map to 1.
        return 0 if exc.code == 0 else 1

    if args.osm2pgsql_binary is None:
        args.osm2pgsql_binary = 'osm2pgsql'
    else:
        args.osm2pgsql_binary = str(Path(args.osm2pgsql_binary).resolve())

    config = Configuration(command_args=[])
    config.show_skipped = False
    runner = Osm2pgsqlRunner(config, args)

    failed = runner.run()

    if runner.undefined_steps:
        LOG.error('Error in feature definition. The following steps are unknown:\n - '
                  + '\n - '.join(f"{s.keyword} {s.name}" for s in runner.undefined_steps))

    return int(failed)
|
|
|
|
if __name__ == '__main__':
    # main() has to run inside the try block; previously the block only
    # contained `pass`, so the handler below could never trigger.
    try:
        retcode = main()
    except Exception as ex:
        LOG.fatal("Exception during execution: %s", ex)
        retcode = 3

    sys.exit(retcode)
|