Files
osm2pgsql/scripts/osm2pgsql-test-style
Sarah Hoffmann cc98e6aab6 forward environment variables to osm2pgsql in test
Windows needs external setup to function properly.
2025-12-03 13:24:09 +01:00

835 lines
30 KiB
Python
Executable File

#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-or-later
#
# This file is part of osm2pgsql (https://osm2pgsql.org/).
#
# Copyright (C) 2025 by the osm2pgsql developer community.
# For a full list of authors see the git log.
"""
Test runner for BDD-style integration tests.
See osm2pgsql manual for more information on osm2pgsql style testing.
"""
import contextlib
import datetime as dt
import importlib.util
import io
import json
import logging
import math
import os
import re
import sys
import tempfile
import urllib.error
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from decimal import Decimal
from importlib.machinery import SourceFileLoader
from pathlib import Path
from subprocess import Popen, PIPE

from behave import given, when, then, use_step_matcher, use_fixture, fixture
from behave.runner import ModelRunner, Context
from behave.formatter.base import StreamOpener
from behave.formatter.pretty import PrettyFormatter
from behave import runner_util
from behave.configuration import Configuration

import psycopg
from psycopg import sql

LOG = logging.getLogger()
use_step_matcher('re')
# Rank of OSM object types for sorting: nodes first, then ways, then relations.
OBJECT_ORDER = {'n': 1, 'w': 2, 'r': 3}


def opl_sort(line):
    """Return a sort key (type rank, object id) for a line in OPL format."""
    object_id, _, _ = line.partition(' ')
    return OBJECT_ORDER[object_id[0]], int(object_id[1:])
#################### Replication mock ##############################
class ReplicationServerMock:
    """Stand-in for pyosmium's ReplicationServer used by replication tests.

    The instance doubles as the server factory: calling it with a base URL
    asserts that the expected replication service was requested and returns
    the mock itself. State infos are expected in ascending timestamp order.
    """

    def __init__(self, base_url, state_infos):
        self.expected_base_url = base_url
        self.state_infos = state_infos

    def __call__(self, base_url):
        assert base_url == self.expected_base_url,\
            f"Wrong replication service called. Expected '{self.expected_base_url}', got '{base_url}'"
        return self

    def get_state_info(self, seq=None, retries=2):
        """Return the state info for sequence *seq*, or the latest if None."""
        assert self.state_infos, 'Replication mock not properly set up'
        if seq is None:
            return self.state_infos[-1]
        return next((info for info in self.state_infos if info.sequence == seq),
                    None)

    def timestamp_to_sequence(self, timestamp, balanced_search=False):
        """Return the sequence of the newest state not after *timestamp*.

        Falls back to the first sequence when *timestamp* predates all states.
        """
        assert self.state_infos, 'Replication mock not properly set up'
        best = self.state_infos[0]
        for info in self.state_infos[1:]:
            if info.timestamp > timestamp:
                break
            best = info
        return best.sequence

    def apply_diffs(self, handler, start_id, max_size=1024, idx="", simplify=True):
        """Pretend to apply diffs: every started kB of max_size covers one diff."""
        newest = self.state_infos[-1].sequence
        if start_id > newest:
            return None
        numdiffs = (max_size + 1023) // 1024
        return min(newest, start_id + numdiffs - 1)
# Replication module is optional
_repfl_spec = importlib.util.spec_from_loader(
'osm2pgsql_replication',
SourceFileLoader('osm2pgsql_replication',
str(Path(__file__, '..', 'osm2pgsql-replication').resolve())))
if _repfl_spec:
osm2pgsql_replication = importlib.util.module_from_spec(_repfl_spec)
_repfl_spec.loader.exec_module(osm2pgsql_replication)
from osmium.replication.server import OsmosisState
else:
osm2pgsql_replication = None
#################### hooks #########################################
def hook_before_all(context):
    """Global test setup.

    Resolves the 'auto' settings for tablespace and proj support, verifies
    that the osm2pgsql binary is executable and creates the template
    database that all per-scenario test databases are cloned from.
    """
    context.config.setup_logging(logging.INFO)
    # Feature check: table spaces
    if context.user_args.test_tablespace == 'auto':
        with context.connect_db('postgres') as conn:
            with conn.cursor() as cur:
                cur.execute("""SELECT spcname FROM pg_tablespace
                               WHERE spcname = 'tablespacetest'""")
                context.user_args.test_tablespace = cur.rowcount > 0
        LOG.info('Check if tablespaces are available: %s',
                 'yes' if context.user_args.test_tablespace else 'no')
    else:
        context.user_args.test_tablespace = context.user_args.test_tablespace == 'yes'
    # Test that osm2pgsql works.
    proc = Popen([context.user_args.osm2pgsql_binary, '--version'],
                 stdout=PIPE, stderr=PIPE)
    _, serr = proc.communicate()
    osm2pgsql_version = serr.decode('utf-8')
    if proc.returncode != 0:
        LOG.critical("Could not run osm2pgsql. Error:\n%s", serr)
        LOG.critical("osm2pgsql binary used: %s", context.user_args.osm2pgsql_binary)
        raise RuntimeError('Error running osm2pgsql')
    # Feature check: proj
    if context.user_args.test_proj == 'auto':
        context.user_args.test_proj = 'Proj [disabled]' not in osm2pgsql_version
    else:
        context.user_args.test_proj = context.user_args.test_proj == 'yes'
    # Log only after the setting has been resolved. Previously this was
    # logged before the check, where the still-set string 'auto' is truthy
    # and was therefore always reported as 'yes'.
    LOG.info('Check if proj is available: %s',
             'yes' if context.user_args.test_proj else 'no')
    use_fixture(template_test_db, context)
def hook_before_scenario(context, scenario):
    """Per-scenario setup: fresh database, working directory and defaults."""
    if 'config.have_proj' in scenario.tags and not context.user_args.test_proj:
        scenario.skip("Generic proj library not configured.")
    context.db = use_fixture(test_db, context)
    context.import_file = None
    context.import_data = None
    context.osm2pgsql_params = {'-d': context.user_args.test_db}
    context.osm2pgsql_returncode = None
    context.workdir = use_fixture(working_directory, context)
    context.nodes = NodeStore()
    context.sql_statements = {}
    context.urlrequest_responses = {}

    # The replication module is optional; only reset its mocks when it could
    # be loaded at all. Previously the attributes were set unconditionally,
    # crashing with AttributeError when osm2pgsql_replication is None.
    if osm2pgsql_replication is not None:
        osm2pgsql_replication.ReplicationServer = None

        def _mock_urlopen(request):
            # Serve canned responses registered via the 'the URL ... returns'
            # step; anything else behaves like a network failure.
            if request.full_url not in context.urlrequest_responses:
                raise urllib.error.URLError('Unknown URL')
            return contextlib.closing(
                io.BytesIO(context.urlrequest_responses[request.full_url].encode('utf-8')))

        osm2pgsql_replication.urlrequest.urlopen = _mock_urlopen
#################### fixtures ######################################
@fixture
def template_test_db(context, **kwargs):
    """Provide a template database with the required extensions installed.

    Per-scenario test databases are cloned from this template, which is much
    faster than creating the extensions for every scenario.
    """
    dbname = context.user_args.template_test_db
    context.drop_db(dbname, recreate_template='default')
    with context.connect_db(dbname) as conn:
        for extension in ('postgis', 'hstore'):
            conn.execute(f'CREATE EXTENSION {extension}')
    yield dbname
    context.drop_db(dbname)
@fixture
def test_db(context, **kwargs):
    """Provide a fresh test database cloned from the template database.

    The database is dropped again after the scenario unless --keep-test-db
    was given on the command line.
    """
    context.drop_db(recreate_template=context.user_args.template_test_db)
    with context.connect_db() as connection:
        yield connection
    if not context.user_args.keep_test_db:
        context.drop_db()
@fixture
def working_directory(context, **kwargs):
    """Provide a scratch directory that is removed after the scenario."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()
################### Node location creation #########################
class NodeStore:
    """Store node locations defined through a scenario grid.

    Nodes can be laid out on a regular grid with the 'the grid' step. OPL
    node lines without explicit coordinates are completed from the grid.
    """

    def __init__(self):
        # Per-instance store. (Previously this was a class attribute, i.e. a
        # mutable dict shared between all NodeStore instances.)
        self.grid = {}

    def set_grid(self, lines, grid_step, origin_x, origin_y):
        """Fill the store from an ASCII grid of single-digit node ids.

        *lines* are the grid rows top to bottom; every digit cell becomes a
        node located at the origin plus the cell offset. The origin refers
        to the bottom-left corner of the grid.
        """
        self.grid = {}
        origin_y -= grid_step * (len(lines) - 1)
        # Number of decimal digits needed so that rounding keeps the grid
        # resolution. Bugfix: the loop used to compare 'step < 0' (never true
        # for a positive step, so ndigits stayed 1 and grids finer than 0.1
        # collapsed) and divided instead of multiplying.
        ndigits = 1
        step = grid_step
        while 0 < step < 1:
            ndigits += 1
            step *= 10
        for y, line in enumerate(lines):
            for x, pt in enumerate(line):
                if pt.isdigit():
                    self.grid[int(pt)] = (round(origin_x + x * grid_step, ndigits),
                                          round(origin_y + y * grid_step, ndigits))

    def get_as_opl(self):
        """Return all grid nodes as OPL node lines."""
        return [f"n{i} x{x} y{y}" for i, (x, y) in self.grid.items()]

    def add_coordinates(self, lines):
        """Yield OPL lines, adding grid coordinates to bare node lines."""
        for line in lines:
            if line.startswith('n') and ' x' not in line:
                nid = int(line.split(' ', 1)[0][1:])
                assert nid in self.grid, \
                    f"OPL error. Node {nid} has no coordinates and is not in grid."
                x, y = self.grid[nid]
                yield f"{line} x{x} y{y}"
            else:
                yield line

    def parse_point(self, pt):
        """Parse a point spec: either 'x y' coordinates or a grid node id."""
        pt = pt.strip()
        if ' ' in pt:
            return list(map(float, pt.split(' ', 1)))
        return self.grid[int(pt)]
################### ResultComparison ###############################
class ResultCompare:
    """Compare a single result column against expected values from a table.

    The column heading may carry a format suffix after '!' selecting the
    comparison function:

      !:<fmt>   apply a Python format spec to the value before comparing
      !~<tol>   float comparison with absolute tolerance <tol>
      !~<pct>%  float comparison with relative tolerance <pct> percent
      !i        case-insensitive string comparison
      !re       regular expression full match
      !substr   substring match
      !json     compare against the expected string parsed as JSON
      !geo      geometry comparison against ST_AsText() output

    Without a suffix, values are compared as plain strings. The expected
    value 'NULL' always matches SQL NULL.
    """

    def __init__(self, heading, nodes):
        self.nodes = nodes
        if '!' in heading:
            self.name, self.fmt = heading.rsplit('!', 1)
            if self.fmt.startswith(':'):
                self.compare = self._intcompare_fmt
            elif self.fmt.startswith('~'):
                if self.fmt.endswith('%'):
                    rel_tol = float(self.fmt[1:-1]) / 100
                    self.compare = lambda exp, val: math.isclose(float(exp), val, rel_tol=rel_tol)
                else:
                    abs_tol = float(self.fmt[1:])
                    self.compare = lambda exp, val: math.isclose(float(exp), val, abs_tol=abs_tol)
            else:
                self.compare = getattr(self, f"_compare_{self.fmt}", None)
                assert self.compare is not None, f"Unknown formatter {self.fmt}"
        else:
            self.name = heading
            self.fmt = None
            self.compare = lambda exp, val: str(val) == exp

    def as_select(self):
        """Return the SQL select expression for this column."""
        if self.fmt == 'geo':
            return f"ST_AsText({self.name})"
        return self.name

    def equals(self, expected, value):
        """True if database *value* matches the *expected* string."""
        if expected == 'NULL':
            return value is None
        return self.compare(expected, value)

    def _intcompare_fmt(self, expected, value):
        return expected == f"{{{self.fmt}}}".format(value)

    def _compare_i(self, expected, value):
        return expected.lower() == str(value).lower()

    def _compare_re(self, expected, value):
        return re.fullmatch(expected, str(value)) is not None

    def _compare_substr(self, expected, value):
        return expected in str(value)

    def _compare_json(self, expected, value):
        return json.loads(expected) == value

    def _compare_geo(self, expected, value):
        m = re.fullmatch(r'([A-Z]+)\((.*)\)', value)
        return self._eq_geom(expected, m[1], m[2]) if m else False

    def _eq_geom(self, bdd_geom, pg_type, pg_coords):
        # MULTI* geometries: '[part;part;...]'
        if bdd_geom.startswith('[') and bdd_geom.endswith(']'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = pg_coords[1:-1].split('),(')
            return pg_type.startswith('MULTI') \
                and len(bdd_parts) == len(pg_parts) \
                and all(self._eq_geom(b.strip(), pg_type[5:], g.strip())
                        for b, g in zip(bdd_parts, pg_parts))
        # GEOMETRYCOLLECTIONS: '{part;part;...}'
        if bdd_geom.startswith('{') and bdd_geom.endswith('}'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = list(map(lambda s: re.fullmatch(r'([A-Z]+)\(([^A-Z]*)\)', s),
                                re.findall('[A-Z]+[^A-Z]+[^,A-Z]', pg_coords)))
            return pg_type.startswith('GEOMETRYCOLLECTION')\
                and len(bdd_parts) == len(pg_parts)\
                and all(g is not None and self._eq_geom(b.strip(), g[1], g[2])
                        for b, g in zip(bdd_parts, pg_parts))
        # POINT
        if ',' not in bdd_geom:
            return pg_type == 'POINT' and self._eq_point(bdd_geom, pg_coords)
        # LINESTRING
        if '(' not in bdd_geom:
            return pg_type == 'LINESTRING' \
                and all(self._eq_point(b, p) for b, p
                        in zip((g.strip() for g in bdd_geom.split(',')),
                               (g.strip() for g in pg_coords.split(','))))
        # POLYGON
        if pg_type != 'POLYGON':
            return False
        # Polygon comparison is tricky because the polygons don't necessarily
        # end at the same point or have the same winding order.
        # Brute force all possible variants of the expected polygon
        bdd_parts = re.findall(r'\([^)]+\)', bdd_geom)
        pg_parts = [g.strip() for g in pg_coords[1:-1].split('),(')]
        return len(bdd_parts) == len(pg_parts) \
            and all(self._eq_ring(*parts) for parts in zip(bdd_parts, pg_parts))

    def _eq_point(self, bdd_pt, pg_pt):
        exp_geom = self.nodes.parse_point(bdd_pt)
        pg_geom = list(map(float, pg_pt.split(' ')))
        return len(exp_geom) == len(pg_geom) \
            and all(math.isclose(e, p, rel_tol=0.000001) for e, p in zip(exp_geom, pg_geom))

    def _eq_ring(self, bdd_ring, pg_ring):
        bdd_pts = [g.strip() for g in bdd_ring[1:-1].split(',')]
        pg_pts = [g.strip() for g in pg_ring.split(',')]
        if bdd_pts[0] != bdd_pts[-1]:
            # Bugfix: the message previously referenced the undefined name
            # 'bdd_geom', raising a NameError instead of this RuntimeError.
            raise RuntimeError(f"Invalid polygon {bdd_ring}. "
                               "First and last point need to be the same")
        if len(bdd_pts) != len(pg_pts):
            return False
        # Try all rotations of the expected ring in both winding orders.
        for line in (bdd_pts[:-1], bdd_pts[-1:0:-1]):
            for i in range(len(line)):
                if all(self._eq_point(p1, p2) for p1, p2 in
                       zip(line[i:] + line[:i], pg_pts)):
                    return True
        return False
################### Steps: Database setup ##########################
@given("the database schema (?P<schema>.+)")
def create_db_schema(context, schema):
    """Create an additional schema in the test database."""
    statement = "CREATE SCHEMA " + schema
    with context.db.cursor() as db_cursor:
        db_cursor.execute(statement)
@when("deleting table (?P<table>.+)")
def delete_table(context, table):
    """Drop the given table from the test database."""
    statement = "DROP TABLE " + table
    with context.db.cursor() as db_cursor:
        db_cursor.execute(statement)
################### Steps: OSM data ################################
@given("the input file '(?P<osm_file>.+)'")
def osm_set_import_file(context, osm_file):
    """Set the OSM file to be read by the next osm2pgsql run.

    Relative names are resolved against the test data directory or,
    failing that, the directory of the feature file.
    """
    assert context.import_data is None, \
        "Import file cannot be used together with inline data."
    pfile = Path(osm_file)
    if not pfile.is_absolute():
        basedir = context.user_args.test_data_dir or Path(context.feature.filename).parent
        pfile = (basedir / osm_file).resolve()
    context.import_file = pfile
@given("the OSM data")
def osm_define_data(context):
    """Append inline OPL data to the data to be imported.

    Node lines without coordinates are completed from the grid.
    """
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."
    raw_lines = context.text.split('\n') if context.text.strip() else []
    context.append_osm_data(context.nodes.add_coordinates(raw_lines))
@given("the OSM data format string")
def osm_define_data(context):
    """Append inline OPL data after expanding it as a Python f-string."""
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."
    # NOTE: eval() of the scenario text is intentional — feature files are
    # trusted test input — but it does execute arbitrary expressions.
    expanded = eval('f"""' + context.text + '"""')
    context.append_osm_data(context.nodes.add_coordinates(expanded.split('\n')))
@given("the (?P<step>[0-9.]+ )?grid(?: with origin (?P<origin_x>[0-9.-]+) (?P<origin_y>[0-9.-]+))?")
def osm_define_node_grid(context, step, origin_x, origin_y):
    """Define node locations from an ASCII grid table."""
    grid_step = float(step.strip()) if step else 0.1
    x = float(origin_x) if origin_x else 20.0
    y = float(origin_y) if origin_y else 20.0
    assert -180.0 < x < 180.0
    assert -90.0 < y < 90.0
    grid_rows = [context.table.headings] + [list(row) for row in context.table]
    context.nodes.set_grid(grid_rows, grid_step, x, y)
    context.append_osm_data(context.nodes.get_as_opl())
################### Steps: Style file ##############################
@given("the style file '(?P<style>.+)'")
def setup_style_file(context, style):
    """Set the style file parameter for the next osm2pgsql run.

    Relative names are looked up in the style data directory, then the test
    data directory and finally next to the feature file.
    """
    sfile = Path(style)
    if not sfile.is_absolute():
        for basedir in (context.user_args.style_data_dir,
                        context.user_args.test_data_dir):
            if basedir is not None and (basedir / sfile).is_file():
                sfile = basedir / sfile
                break
        else:
            sfile = Path(context.feature.filename).parent / sfile
    assert sfile.is_file()
    context.osm2pgsql_params['-S'] = str(sfile.resolve())
@given("the lua style")
def setup_style_inline(context):
    """Write the scenario text into a Lua style file and use that style."""
    style_path = context.workdir / 'inline_style.lua'
    style_path.write_text(context.text)
    context.osm2pgsql_params['-S'] = str(style_path)
################### Steps: Running osm2pgsql #######################
@when(r"running osm2pgsql(?P<output> \w+)?(?: with parameters)?")
def execute_osm2pgsql(context, output):
    """Run osm2pgsql with the collected parameters and the import data.

    The optional table supplies additional command line parameters; the
    placeholders {TEST_DATA_DIR}, {STYLE_DATA_DIR} and {TEST_DB} are
    expanded in them. Inline OPL data is fed via stdin, otherwise the
    previously set import file is used. Results (return code, stdout,
    stderr, command line) are stored on the context for later steps.
    """
    assert output in (' flex', ' pgsql', ' null', None)
    if output is not None:
        context.osm2pgsql_params['-O'] = output.strip()
    # The pgsql output needs a default style unless one was set explicitly.
    if output == ' pgsql' and '-S' not in context.osm2pgsql_params:
        context.osm2pgsql_params['-S'] = '{STYLE_DATA_DIR}/default.style'
    cmdline = [context.user_args.osm2pgsql_binary]
    test_dir = (context.user_args.test_data_dir or Path('.')).resolve()
    style_dir = (context.user_args.style_data_dir or Path('.')).resolve()

    def _template(param):
        # Expand the placeholder variables allowed in parameter values.
        return param.replace('{TEST_DATA_DIR}', str(test_dir))\
                    .replace('{STYLE_DATA_DIR}', str(style_dir))\
                    .replace('{TEST_DB}', context.user_args.test_db)

    if context.table:
        assert not any('<' in h for h in context.table.headings), \
            "Substition in the first line of a table are not supported."
        cmdline.extend(_template(h) for h in context.table.headings if h)
        for row in context.table:
            cmdline.extend(_template(c) for c in row if c)
    # Add collected default parameters unless overridden in the table.
    for k, v in context.osm2pgsql_params.items():
        if k not in cmdline:
            cmdline.extend((k, _template(v)))
    # Skip scenarios that need the test tablespace when it is not set up.
    if not context.user_args.test_tablespace\
       and any(p.startswith('--tablespace') for p in cmdline):
        context.scenario.skip('tablespace tablespacetest not available')
        return
    if context.import_data is not None:
        # Inline data: sort objects into canonical OPL order and pipe them in.
        data_stdin = '\n'.join(sorted(context.import_data.values(), key=opl_sort)).encode('utf-8')
        context.import_data = None
        cmdline.extend(('-r', 'opl', '-'))
    else:
        assert context.import_file is not None, "No input data given."
        cmdline.append(str(context.import_file))
        context.import_file = None
        data_stdin = None
    proc = Popen(cmdline, cwd=str(context.workdir),
                 stdin=PIPE, stdout=PIPE, stderr=PIPE)
    outdata = proc.communicate(input=data_stdin)
    context.osm2pgsql_cmdline = ' '.join(cmdline)
    context.osm2pgsql_outdata = [d.decode('utf-8').replace('\\n', '\n') for d in outdata]
    context.osm2pgsql_returncode = proc.returncode
    # Reset the parameters for the next run within the same scenario.
    context.osm2pgsql_params = {'-d': context.user_args.test_db}
@then("execution is successful")
def osm2pgsql_check_success(context):
    """Check that the last osm2pgsql run exited with return code 0."""
    retcode = context.osm2pgsql_returncode
    assert retcode == 0, \
        f"osm2pgsql failed with error code {retcode}.\n"\
        f"Command line: {context.osm2pgsql_cmdline}\n"\
        f"Output:\n{context.osm2pgsql_outdata[0]}\n{context.osm2pgsql_outdata[1]}\n"
@then(r"execution fails(?: with return code (?P<expected>\d+))?")
def osm2pgsql_check_failure(context, expected):
    """Check that the last run failed, optionally with a specific code."""
    retcode = context.osm2pgsql_returncode
    assert retcode != 0, "osm2pgsql unexpectedly succeeded"
    if expected is not None:
        assert retcode == int(expected), \
            f"osm2pgsql failed with return code {retcode} instead of {expected}\n"\
            f"Output:\n{context.osm2pgsql_outdata[0]}\n{context.osm2pgsql_outdata[1]}\n"
@then(r"the (?P<kind>\w+) output contains")
def check_program_output(context, kind):
    """Check that every line of the step text occurs in the chosen output."""
    if kind == 'error':
        output = context.osm2pgsql_outdata[1]
    elif kind == 'standard':
        output = context.osm2pgsql_outdata[0]
    else:
        assert not "Expect one of error, standard"
    for raw_line in context.text.split('\n'):
        expected = raw_line.strip()
        assert expected in output,\
            f"Output '{expected}' not found in {kind} output:\n{output}\n"
################### Steps: Running Replication #####################
@given("the replication service at (?P<base_url>.*)")
def setup_replication_mock(context, base_url):
    """Install the replication server mock with the given state table."""
    if osm2pgsql_replication is None:
        context.scenario.skip("Replication binary not available. Skip.")
        return
    state_infos = []
    if context.table:
        for row in context.table:
            when = dt.datetime.strptime(row[1], '%Y-%m-%dT%H:%M:%SZ')\
                              .replace(tzinfo=dt.timezone.utc)
            state_infos.append(OsmosisState(int(row[0]), when))
    osm2pgsql_replication.ReplicationServer = ReplicationServerMock(base_url, state_infos)
@given("the URL (?P<base_url>.*) returns")
def mock_url_response(context, base_url):
    """Register a canned response for the mocked urlopen()."""
    context.urlrequest_responses.update({base_url: context.text})
@when("running osm2pgsql-replication")
def execute_osm2pgsql_replication(context):
    """Run osm2pgsql-replication's main() in-process.

    The optional table supplies command line arguments (with the same
    placeholder expansion as the osm2pgsql step). stdout is captured via
    redirection, stderr via a temporary handler on the script's logger;
    both are stored on the context together with the return code.
    """
    assert osm2pgsql_replication is not None
    assert osm2pgsql_replication.ReplicationServer is not None
    cmdline = []
    test_dir = (context.user_args.test_data_dir or Path('.')).resolve()
    style_dir = (context.user_args.style_data_dir or Path('.')).resolve()

    def _template(param):
        # Expand the placeholder variables allowed in parameter values.
        return param.replace('{TEST_DATA_DIR}', str(test_dir))\
                    .replace('{STYLE_DATA_DIR}', str(style_dir))\
                    .replace('{TEST_DB}', context.user_args.test_db)

    if context.table:
        assert not any('<' in h for h in context.table.headings), \
            "Substition in the first line of a table are not supported."
        cmdline.extend(_template(h) for h in context.table.headings if h)
        for row in context.table:
            cmdline.extend(_template(c) for c in row if c)
    # Default to the test database unless one was given explicitly.
    if '-d' not in cmdline and '--database' not in cmdline:
        cmdline.extend(('-d', context.user_args.test_db))
    if cmdline[0] == 'update':
        # The update command runs osm2pgsql itself; point it at the binary
        # under test and supply a default style when none was given.
        cmdline.extend(('--osm2pgsql-cmd', context.user_args.osm2pgsql_binary))
        if '--' not in cmdline:
            cmdline.extend(('--', '-S', str(style_dir / 'default.style')))
    serr = io.StringIO()
    log_handler = logging.StreamHandler(serr)
    osm2pgsql_replication.LOG.addHandler(log_handler)
    with contextlib.redirect_stdout(io.StringIO()) as sout:
        context.osm2pgsql_returncode = osm2pgsql_replication.main(cmdline)
    osm2pgsql_replication.LOG.removeHandler(log_handler)
    context.osm2pgsql_outdata = [sout.getvalue(), serr.getvalue()]
################### Steps: Inspect database ########################
@given("the SQL statement (?P<sql>.+)")
def db_define_sql_statement(context, sql):
    """Save a named SQL statement for a later 'statement ... returns' step.

    Note: the parameter name is dictated by the regex group and shadows the
    psycopg ``sql`` module within this function.
    """
    context.sql_statements.update({sql: context.text})
@then("there are (?P<exists>no )?tables (?P<tables>.+)")
def db_table_existance(context, exists, tables):
    """Check presence (or, with 'no', absence) of a comma-separated table list."""
    want_missing = exists == 'no '
    for table in (t.strip() for t in tables.split(',')):
        if want_missing:
            assert not context.table_exists(table), f"Table '{table}' unexpectedly found"
        else:
            assert context.table_exists(table), f"Table '{table}' not found"
@then("table (?P<table>.+) contains(?P<exact> exactly)?")
def db_check_table_content(context, table, exact):
    """Check that a table contains the rows of the step table.

    With 'exactly', no other rows may be present.
    """
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."
    target = sql.Identifier(*table.split('.', 1))
    context.check_select(target, exact is not None)
@then("table (?P<table>.+) doesn't contain")
def db_check_table_content_negative(context, table):
    """Check that none of the rows of the step table appear in the table."""
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."
    target = sql.Identifier(*table.split('.', 1))
    context.check_select_not_contained(target)
@then(r"table (?P<table>.+) has (?P<row_num>\d+) rows?")
def db_table_row_count(context, table, row_num):
    """Check the total number of rows in the given table."""
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."
    # maxsplit=1 ('schema.table'), consistent with the other table steps;
    # the previous maxsplit of 2 silently accepted a stray third part.
    query = sql.SQL("SELECT count(*) FROM {}").format(sql.Identifier(*table.split('.', 1)))
    for res in context.db.execute(query):
        assert res[0] == int(row_num),\
            f"Table {table}: expected {row_num} rows, got {res[0]}"
@then("statement (?P<stmt>.+) returns(?P<exact> exactly)?")
def db_check_table_content(context, stmt, exact):
    """Check the result of a previously defined SQL statement."""
    context.execute_steps("Then execution is successful")
    assert stmt in context.sql_statements
    from_clause = sql.SQL(f"({context.sql_statements[stmt]}) _statement_sql")
    context.check_select(from_clause, exact is not None)
################### Context ########################################
class Osm2pgsqlContext(Context):
    """Behave context with database helper functions for the osm2pgsql tests."""

    def __init__(self, runner, args):
        super().__init__(runner)
        self.user_args = args

    def connect_db(self, name=None):
        """Open an autocommit connection to the given (or default test) DB."""
        dbname = name or self.user_args.test_db
        return psycopg.connect(dbname=dbname, autocommit=True)

    def drop_db(self, name=None, recreate_template=None):
        """Drop the given database and optionally recreate it.

        recreate_template may be 'default' for an empty database or the
        name of a template database to clone from.
        """
        db = sql.Identifier(name or self.user_args.test_db)
        with self.connect_db('postgres') as conn:
            conn.execute(sql.SQL('DROP DATABASE IF EXISTS {}').format(db))
            if recreate_template == 'default':
                conn.execute(sql.SQL('CREATE DATABASE {}').format(db))
            elif recreate_template:
                conn.execute(sql.SQL('CREATE DATABASE {} WITH TEMPLATE {}')
                                .format(db, sql.Identifier(recreate_template)))

    def append_osm_data(self, lines, include_untagged=None):
        """Add OPL lines to the import data, deduplicating by object id."""
        if self.import_data is None:
            self.import_data = {}
        for line in lines:
            if (l := line.strip()):
                self.import_data[l.split(' ', 1)[0]] = l

    def table_exists(self, table):
        """Return True when the given table or view exists.

        A table name without a schema refers to the public schema.
        """
        sql_params = table.split('.', 1) if '.' in table else ('public', table)
        # Bugfix: the query previously used the literal 'pg_tables' for both
        # catalogs, so views were never found (pg_views also names its column
        # 'viewname', not 'tablename').
        for catalog, name_col in (('pg_tables', 'tablename'),
                                  ('pg_views', 'viewname')):
            query = f'SELECT count(*) FROM {catalog} WHERE schemaname = %s AND {name_col} = %s'
            for res in self.db.execute(query, sql_params):
                if res[0] == 1:
                    return True
        return False

    def check_select(self, from_clause, exact):
        """Compare the result of a SELECT against the current step table.

        Every expected row must match some result row. With exact=True the
        result may not contain additional rows either.
        """
        rows = [ResultCompare(h, self.nodes) for h in self.table.headings]
        lines = set(range(0, len(self.table.rows)))
        query = sql.SQL('SELECT {} FROM {}').format(
            sql.SQL(', '.join(f"({r.as_select()}) as c{i}" for i, r in enumerate(rows))),
            from_clause)
        table_content = ''
        unexpected_rows = []
        for row in self.db.execute(query):
            table_content += f"\n{row}"
            for i in lines:
                for attr, expected, value in zip(rows, self.table[i], row):
                    if not attr.equals(expected, value):
                        break
                else:
                    # Row matches this expected line; each line matches once.
                    lines.remove(i)
                    break
            else:
                if exact:
                    unexpected_rows.append(str(row))
        assert not lines, \
            "Rows not found:\n" \
            + '\n'.join(str(self.table[i]) for i in lines) \
            + "\nTable content:\n" \
            + table_content
        assert not unexpected_rows, \
            "Unexpected rows found:\n" + '\n'.join(unexpected_rows)\
            + "\nTable content:\n" \
            + table_content

    def check_select_not_contained(self, from_clause):
        """Assert that no row of the step table appears in the SELECT result."""
        rows = [ResultCompare(h, self.nodes) for h in self.table.headings]
        lines = set(range(0, len(self.table.rows)))
        query = sql.SQL('SELECT {} FROM {}').format(
            sql.SQL(', '.join(f"({r.as_select()}) as c{i}" for i, r in enumerate(rows))),
            from_clause)
        table_content = ''
        matching_rows = []
        for row in self.db.execute(query):
            table_content += f"\n{row}"
            for i in lines:
                for attr, expected, value in zip(rows, self.table[i], row):
                    if not attr.equals(expected, value):
                        break
                else:
                    matching_rows.append(str(row))
                    break
        assert not matching_rows, \
            "Matching rows found:\n" + '\n'.join(matching_rows)\
            + "\nFull table content:\n" \
            + table_content
#################### runner and main ###############################
class Osm2pgsqlRunner(ModelRunner):
    """Minimal behave runner wiring up our hooks, context and formatter."""

    def __init__(self, config, args):
        super().__init__(config)
        self.feature_locations = runner_util.collect_feature_locations(args.features)
        self.hooks = {'before_all': hook_before_all,
                      'before_scenario': hook_before_scenario}
        self.context = Osm2pgsqlContext(self, args)

    def run(self):
        """Parse the feature files and run them; return the failure state."""
        self.features.extend(runner_util.parse_features(self.feature_locations))
        self.formatters = [PrettyFormatter(StreamOpener(stream=sys.stdout),
                                           self.config)]
        return self.run_model()
def get_parser():
    """Build the command line parser for the test runner."""
    p = ArgumentParser(description=__doc__,
                       prog='osm2pgsql-test-style',
                       formatter_class=RawDescriptionHelpFormatter)
    # Positional arguments.
    p.add_argument('features', nargs='+',
                   help='Feature files or paths')
    # Paths to binaries and data.
    p.add_argument('--osm2pgsql-binary',
                   help='osm2pgsql binary to use for testing (default: osm2pgsql)')
    p.add_argument('--test-data-dir', type=Path,
                   help='(optional) directory to search for test data')
    p.add_argument('--style-data-dir', type=Path,
                   help='(optional) directory to search for style files')
    # Database handling.
    p.add_argument('--test-db', default='osm2pgsql-test',
                   help='Name of database to use for testing (default: osm2pgsql-test)')
    p.add_argument('--template-test-db', default='osm2pgsql-test-template',
                   help='Name of database to use for creating the template db '
                        '(default: osm2pgsql-test-template)')
    p.add_argument('--keep-test-db', action='store_true',
                   help='Keep the test database around after tests are done')
    # Optional features.
    p.add_argument('--test-tablespace', default='auto', choices=['yes', 'no', 'auto'],
                   help='Include tests requiring a tablespace')
    p.add_argument('--test-proj', default='auto', choices=['yes', 'no', 'auto'],
                   help='Include tests requiring the proj library')
    return p
def main(prog_args=None):
    """Parse the command line and run the tests; return the exit code."""
    parser = get_parser()
    try:
        args = parser.parse_args(args=prog_args)
    except SystemExit:
        return 1
    # Resolve the binary path once; a bare name is looked up via PATH.
    if args.osm2pgsql_binary is None:
        args.osm2pgsql_binary = 'osm2pgsql'
    else:
        args.osm2pgsql_binary = str(Path(args.osm2pgsql_binary).resolve())
    config = Configuration(command_args=[])
    config.show_skipped = False
    runner = Osm2pgsqlRunner(config, args)
    failed = runner.run()
    if runner.undefined_steps:
        unknown = '\n - '.join(f"{s.keyword} {s.name}" for s in runner.undefined_steps)
        LOG.error('Error in feature definition. The following steps are unknown:\n - '
                  + unknown)
    return int(failed)
if __name__ == '__main__':
    # Map any unexpected exception to a distinct exit code. Bugfix: main()
    # was previously called before the try block (whose body was just
    # 'pass'), so the exception handler could never fire.
    try:
        retcode = main()
    except Exception as ex:
        LOG.fatal("Exception during execution: %s", ex)
        retcode = 3
    sys.exit(retcode)