#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-or-later
#
# This file is part of osm2pgsql (https://osm2pgsql.org/).
#
# Copyright (C) 2025 by the osm2pgsql developer community.
# For a full list of authors see the git log.
"""
Test runner for BDD-style integration tests.

See osm2pgsql manual for more information on osm2pgsql style testing.
"""
import logging
import sys
import tempfile
import math
import re
import os
import contextlib
import json
import datetime as dt
import urllib.error
from decimal import Decimal
from subprocess import Popen, PIPE
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from pathlib import Path
import importlib.util
import io
from importlib.machinery import SourceFileLoader

from behave import given, when, then, use_step_matcher, use_fixture, fixture
from behave.runner import ModelRunner, Context
from behave.formatter.base import StreamOpener
from behave.formatter.pretty import PrettyFormatter
from behave import runner_util
from behave.configuration import Configuration

import psycopg
from psycopg import sql

LOG = logging.getLogger()

# All step patterns in this file are regular expressions; named groups are
# handed to the step function as keyword arguments.
use_step_matcher('re')

# Rank of the OSM object types, keyed by the type letter used in OPL ids.
OBJECT_ORDER = {'n': 1, 'w': 2, 'r': 3}


def opl_sort(line):
    """ Sort key for an OPL line.

        The first space-separated token of the line is the object id
        (type letter followed by a number, e.g. 'n12'). The key orders
        by object type as ranked in OBJECT_ORDER and then by numeric id.
    """
    oid = line.split(' ', 1)[0]
    return OBJECT_ORDER[oid[0]], int(oid[1:])


#################### Replication mock ##############################

class ReplicationServerMock:
    """ Stand-in for a replication server that answers from a fixed list
        of state infos instead of going to the network.

        The state info objects must provide the attributes `sequence` and
        `timestamp`; the lookups below expect the list to be ordered from
        oldest to newest.

        NOTE(review): the method signatures presumably mirror pyosmium's
        ReplicationServer; parameters that are not needed for the mock
        are accepted and ignored.
    """

    def __init__(self, base_url, state_infos):
        self.expected_base_url = base_url
        self.state_infos = state_infos

    def __call__(self, base_url):
        """ Pretend to be the constructor of the server class: check that
            the expected URL is requested and hand out this mock.
        """
        assert base_url == self.expected_base_url,\
            ("Wrong replication service called. "
             f"Expected '{self.expected_base_url}', got '{base_url}'")
        return self

    def get_state_info(self, seq=None, retries=2):
        """ Return the state info with sequence number `seq`, the newest
            state info when `seq` is None, or None when the sequence is
            unknown. `retries` is ignored.
        """
        assert self.state_infos, 'Replication mock not properly set up'
        if seq is None:
            return self.state_infos[-1]

        for info in self.state_infos:
            if info.sequence == seq:
                return info

        return None

    def timestamp_to_sequence(self, timestamp, balanced_search=False):
        """ Return the sequence number of the last state whose timestamp
            is not after `timestamp`. Timestamps before the first state
            map to the first state. `balanced_search` is ignored.
        """
        assert self.state_infos, 'Replication mock not properly set up'

        if timestamp < self.state_infos[0].timestamp:
            return self.state_infos[0].sequence

        prev = self.state_infos[0]
        for info in self.state_infos:
            if prev.timestamp <= timestamp < info.timestamp:
                return prev.sequence
            prev = info

        return prev.sequence

    def apply_diffs(self, handler, start_id, max_size=1024, idx="", simplify=True):
        """ Pretend to apply diffs beginning with sequence `start_id`.

            Nothing is downloaded and `handler` is never called. One diff
            is accounted for every 1024 units of `max_size` (rounded up).
            Returns the last sequence number 'applied' or None when
            `start_id` is beyond the newest known state.
        """
        if start_id > self.state_infos[-1].sequence:
            return None

        numdiffs = (max_size + 1023) // 1024
        return min(self.state_infos[-1].sequence, start_id + numdiffs - 1)


# Replication module is optional
_repfl_spec = importlib.util.spec_from_loader(
    'osm2pgsql_replication',
    SourceFileLoader('osm2pgsql_replication',
                     str(Path(__file__, '..', 'osm2pgsql-replication').resolve())))

if _repfl_spec:
    osm2pgsql_replication = importlib.util.module_from_spec(_repfl_spec)
    _repfl_spec.loader.exec_module(osm2pgsql_replication)

    from osmium.replication.server import OsmosisState
else:
    osm2pgsql_replication = None


#################### hooks #########################################

def hook_before_all(context):
    """ Global set-up: configure logging, resolve the feature switches
        `test_tablespace` and `test_proj` in `context.user_args` to
        booleans, check that the osm2pgsql binary can be run and create
        the template database.

        Raises RuntimeError when 'osm2pgsql --version' fails.
    """
    context.config.setup_logging(logging.INFO)

    # Feature check: table spaces
    if context.user_args.test_tablespace == 'auto':
        with context.connect_db('postgres') as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT spcname FROM pg_tablespace "
                            "WHERE spcname = 'tablespacetest'")
                context.user_args.test_tablespace = cur.rowcount > 0
        LOG.info('Check if tablespaces are available: %s',
                 'yes' if context.user_args.test_tablespace else 'no')
    else:
        context.user_args.test_tablespace = context.user_args.test_tablespace == 'yes'

    # Test that osm2pgsql works.
    proc = Popen([context.user_args.osm2pgsql_binary, '--version'],
                 stdout=PIPE, stderr=PIPE)
    _, serr = proc.communicate()
    # The version information is read from stderr.
    osm2pgsql_version = serr.decode('utf-8')
    if proc.returncode != 0:
        LOG.critical("Could not run osm2pgsql. Error:\n%s", serr)
        LOG.critical("osm2pgsql binary used: %s", context.user_args.osm2pgsql_binary)
        raise RuntimeError('Error running osm2pgsql')

    # Feature check: proj
    if context.user_args.test_proj == 'auto':
        context.user_args.test_proj = 'Proj [disabled]' not in osm2pgsql_version
    else:
        context.user_args.test_proj = context.user_args.test_proj == 'yes'
    # Log only after the switch has been resolved. Before that it is a
    # non-empty string and would always be reported as 'yes'.
    LOG.info('Check if proj is available: %s',
             'yes' if context.user_args.test_proj else 'no')

    use_fixture(template_test_db, context)


def hook_before_scenario(context, scenario):
    """ Per-scenario set-up: skip scenarios that need proj when it is not
        available, create a fresh test database and working directory and
        reset all per-scenario state on the context.
    """
    if 'config.have_proj' in scenario.tags and not context.user_args.test_proj:
        scenario.skip("Generic proj library not configured.")

    context.db = use_fixture(test_db, context)
    context.import_file = None
    context.import_data = None
    context.osm2pgsql_params = {'-d': context.user_args.test_db}
    context.osm2pgsql_returncode = None
    context.workdir = use_fixture(working_directory, context)
    context.nodes = NodeStore()
    context.sql_statements = {}
    context.urlrequest_responses = {}

    # The replication module is optional, only patch it when it was loaded.
    if osm2pgsql_replication is not None:
        osm2pgsql_replication.ReplicationServer = None

        def _mock_urlopen(request):
            # Serve canned responses registered by the scenario, never
            # touch the network.
            if request.full_url not in context.urlrequest_responses:
                raise urllib.error.URLError('Unknown URL')

            response = context.urlrequest_responses[request.full_url]
            return contextlib.closing(io.BytesIO(response.encode('utf-8')))

        osm2pgsql_replication.urlrequest.urlopen = _mock_urlopen


#################### fixtures ######################################

@fixture
def template_test_db(context, **kwargs):
    """ Fixture for the template database: (re)create it, install the
        postgis and hstore extensions and drop it again on clean-up.
        Yields the name of the template database.
    """
    context.drop_db(context.user_args.template_test_db, recreate_template='default')

    with context.connect_db(context.user_args.template_test_db) as conn:
        conn.execute('CREATE EXTENSION postgis')
        conn.execute('CREATE EXTENSION hstore')

    yield context.user_args.template_test_db

    context.drop_db(context.user_args.template_test_db)


@fixture
def test_db(context, **kwargs):
    """ Fixture for the test database: recreate it from the template and
        yield an open connection. On clean-up the database is dropped
        unless `keep_test_db` is set in the user arguments.
    """
    context.drop_db(recreate_template=context.user_args.template_test_db)

    with context.connect_db() as conn:
        yield conn

    if not context.user_args.keep_test_db:
        context.drop_db()


@fixture
def working_directory(context, **kwargs):
    """ Fixture yielding the path of a temporary directory which is
        removed again on clean-up.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)


################### Node location creation #########################

def _decimal_places(value):
    """ Return the number of decimal places in the shortest string
        representation of the float `value`, but at least 1.
    """
    return max(1, -Decimal(str(value)).as_tuple().exponent)


class NodeStore:
    """ Store of node coordinates that are defined through an ASCII grid.

        `grid` maps the node id to a tuple (x, y).
    """

    def __init__(self):
        self.grid = {}

    def set_grid(self, lines, grid_step, origin_x, origin_y):
        """ Replace the grid with the one described by `lines`.

            `lines` is a sequence of rows, each row a sequence of cells.
            Every cell that consists only of digits defines a node with
            that id. Neighbouring cells and rows are `grid_step` apart.
            The first cell of the first row has the x coordinate
            `origin_x`; rows are numbered from an origin that is shifted
            so that the last row has the y coordinate `origin_y`.
        """
        self.grid = {}
        origin_y -= grid_step * (len(lines) - 1)

        # Rounding only serves to get rid of floating point noise. Keep
        # as many digits as step and origin need, so that fine grids are
        # not collapsed onto a coarser raster.
        ndigits = max(_decimal_places(v) for v in (grid_step, origin_x, origin_y))

        for y, line in enumerate(lines):
            for x, pt in enumerate(line):
                if pt.isdigit():
                    self.grid[int(pt)] = (round(origin_x + x * grid_step, ndigits),
                                          round(origin_y + y * grid_step, ndigits))

    def get_as_opl(self):
        """ Return the nodes of the grid as a list of OPL lines.
        """
        return [f"n{i} x{x} y{y}" for i, (x, y) in self.grid.items()]

    def add_coordinates(self, lines):
        """ Generator over the given OPL lines where node lines that have
            no x coordinate are completed with the coordinates from the
            grid. All other lines are passed through unchanged.
        """
        for line in lines:
            if line.startswith('n') and ' x' not in line:
                nid = int(line.split(' ', 1)[0][1:])
                assert nid in self.grid, \
                    f"OPL error. Node {nid} has no coordinates and is not in grid."
                x, y = self.grid[nid]
                yield f"{line} x{x} y{y}"
            else:
                yield line

    def parse_point(self, pt):
        """ Return the coordinates for a point description, which is
            either a space-separated coordinate pair or the id of a node
            in the grid.
        """
        pt = pt.strip()
        if ' ' in pt:
            return list(map(float, pt.split(' ', 1)))
        return self.grid[int(pt)]


################### ResultComparison ###############################

class ResultCompare:
    """ Comparator for one column of an expected result table.

        The column heading has the form 'name' or 'name!format'. Without
        a format, values are compared by their string representation.
        Supported formats:

          :<spec>   format the value with the Python format spec, then
                    compare the strings
          ~<tol>    compare as float with an absolute tolerance
          ~<tol>%   compare as float with a relative tolerance in percent
          i         case-insensitive string comparison
          re        value must fully match the regular expression
          substr    expected string must be contained in the value
          json      compare against the parsed JSON of the expected string
          geo       compare geometries, column is selected via ST_AsText()
    """

    def __init__(self, heading, nodes):
        self.nodes = nodes
        if '!' in heading:
            self.name, self.fmt = heading.rsplit('!', 1)
            if self.fmt.startswith(':'):
                self.compare = self._intcompare_fmt
            elif self.fmt.startswith('~'):
                if self.fmt.endswith('%'):
                    rel_tol = float(self.fmt[1:-1]) / 100
                    self.compare = lambda exp, val: math.isclose(float(exp), val,
                                                                 rel_tol=rel_tol)
                else:
                    abs_tol = float(self.fmt[1:])
                    self.compare = lambda exp, val: math.isclose(float(exp), val,
                                                                 abs_tol=abs_tol)
            else:
                self.compare = getattr(self, f"_compare_{self.fmt}", None)
                assert self.compare is not None, f"Unknown formatter {self.fmt}"
        else:
            self.name = heading
            self.fmt = None
            self.compare = lambda exp, val: str(val) == exp

    def as_select(self):
        """ Return the SQL expression to use for selecting the column.
        """
        if self.fmt == 'geo':
            return f"ST_AsText({self.name})"
        return self.name

    def equals(self, expected, value):
        """ Check if the database value matches the expected string.
            The expected string 'NULL' matches only None.
        """
        if expected == 'NULL':
            return value is None
        return self.compare(expected, value)

    def _intcompare_fmt(self, expected, value):
        # self.fmt starts with ':', so this builds '{:<spec>}'.
        return expected == f"{{{self.fmt}}}".format(value)

    def _compare_i(self, expected, value):
        return expected.lower() == str(value).lower()

    def _compare_re(self, expected, value):
        return re.fullmatch(expected, str(value)) is not None

    def _compare_substr(self, expected, value):
        return expected in str(value)

    def _compare_json(self, expected, value):
        return json.loads(expected) == value

    def _compare_geo(self, expected, value):
        # A NULL geometry can never match a geometry description.
        if value is None:
            return False
        m = re.fullmatch(r'([A-Z]+)\((.*)\)', value)
        return self._eq_geom(expected, m[1], m[2]) if m else False

    def _eq_geom(self, bdd_geom, pg_type, pg_coords):
        """ Compare the expected geometry `bdd_geom` against a geometry in
            WKT that is already split into type and coordinate part.

            Expected geometries are written as: '[a; b]' for MULTI*
            geometries, '{a; b}' for geometry collections, a single point
            for points, a comma-separated point list for linestrings and
            a sequence of '(ring)' for polygons.
        """
        # MULTI* geometries
        if bdd_geom.startswith('[') and bdd_geom.endswith(']'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = pg_coords[1:-1].split('),(')
            return pg_type.startswith('MULTI') \
                and len(bdd_parts) == len(pg_parts) \
                and all(self._eq_geom(b.strip(), pg_type[5:], g.strip())
                        for b, g in zip(bdd_parts, pg_parts))

        # GEOMETRYCOLLECTIONS
        if bdd_geom.startswith('{') and bdd_geom.endswith('}'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = [re.fullmatch(r'([A-Z]+)\(([^A-Z]*)\)', s)
                        for s in re.findall('[A-Z]+[^A-Z]+[^,A-Z]', pg_coords)]
            return pg_type.startswith('GEOMETRYCOLLECTION')\
                and len(bdd_parts) == len(pg_parts)\
                and all(g is not None and self._eq_geom(b.strip(), g[1], g[2])
                        for b, g in zip(bdd_parts, pg_parts))

        # POINT
        if ',' not in bdd_geom:
            return pg_type == 'POINT' and self._eq_point(bdd_geom, pg_coords)

        # LINESTRING
        if '(' not in bdd_geom:
            bdd_pts = [g.strip() for g in bdd_geom.split(',')]
            pg_pts = [g.strip() for g in pg_coords.split(',')]
            # zip() stops at the shorter sequence, so the number of points
            # must be compared explicitly. Otherwise a line that is only
            # a prefix of the other would be accepted.
            return pg_type == 'LINESTRING' \
                and len(bdd_pts) == len(pg_pts) \
                and all(self._eq_point(b, p) for b, p in zip(bdd_pts, pg_pts))

        # POLYGON
        if pg_type != 'POLYGON':
            return False

        # Polygon comparison is tricky because the polygons don't necessarily
        # end at the same point or have the same winding order.
        # Brute force all possible variants of the expected polygon
        bdd_parts = re.findall(r'\([^)]+\)', bdd_geom)
        pg_parts = [g.strip() for g in pg_coords[1:-1].split('),(')]
        return len(bdd_parts) == len(pg_parts) \
            and all(self._eq_ring(*parts) for parts in zip(bdd_parts, pg_parts))

    def _eq_point(self, bdd_pt, pg_pt):
        """ Compare an expected point (coordinates or grid node id) with
            a space-separated coordinate string from the database.
        """
        exp_geom = self.nodes.parse_point(bdd_pt)
        pg_geom = list(map(float, pg_pt.split(' ')))
        return len(exp_geom) == len(pg_geom) \
            and all(math.isclose(e, p, rel_tol=0.000001)
                    for e, p in zip(exp_geom, pg_geom))

    def _eq_ring(self, bdd_ring, pg_ring):
        """ Compare an expected ring '(p1, p2, ..., p1)' with a ring from
            the database. The rings are equal when any rotation of the
            expected ring, in either direction, matches.

            Raises RuntimeError when the expected ring is not closed.
        """
        bdd_pts = [g.strip() for g in bdd_ring[1:-1].split(',')]
        pg_pts = [g.strip() for g in pg_ring.split(',')]
        if bdd_pts[0] != bdd_pts[-1]:
            raise RuntimeError(f"Invalid polygon {bdd_ring}. "
                               "First and last point need to be the same")
        if len(bdd_pts) != len(pg_pts):
            return False

        # Try the ring without its closing point in both directions and
        # with every possible starting point.
        for line in (bdd_pts[:-1], bdd_pts[-1:0:-1]):
            for i in range(len(line)):
                if all(self._eq_point(p1, p2)
                       for p1, p2 in zip(line[i:] + line[:i], pg_pts)):
                    return True

        return False


################### Steps: Database setup ##########################

@given("the database schema (?P<schema>.+)")
def create_db_schema(context, schema):
    """ Create the given schema in the test database.
    """
    # The name comes from the feature file and is used verbatim.
    with context.db.cursor() as cur:
        cur.execute("CREATE SCHEMA " + schema)


@when("deleting table (?P<table>.+)")
def delete_table(context, table):
    """ Drop the given table from the test database.
    """
    # The name comes from the feature file and is used verbatim.
    with context.db.cursor() as cur:
        cur.execute("DROP TABLE " + table)


################### Steps: OSM data ################################

@given("the input file '(?P<osm_file>.+)'")
def osm_set_import_file(context, osm_file):
    """ Use the given file as input. Relative paths are resolved against
        the configured test data directory or, when none is set, against
        the directory of the feature file.
    """
    assert context.import_data is None, \
        "Import file cannot be used together with inline data."
    pfile = Path(osm_file)
    if pfile.is_absolute():
        context.import_file = pfile
    else:
        basedir = context.user_args.test_data_dir or Path(context.feature.filename).parent
        context.import_file = (basedir / osm_file).resolve()


@given("the OSM data")
def osm_define_data(context):
    """ Add the OPL lines from the step text to the input data. Nodes
        without coordinates get them from the node grid.
    """
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."
    if context.text.strip():
        context.append_osm_data(context.nodes.add_coordinates(context.text.split('\n')))
    else:
        context.append_osm_data([])


@given("the OSM data format string")
def osm_define_data(context):
    """ Like 'the OSM data' but the step text is evaluated as a Python
        f-string first.
    """
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."
    # NOTE(security): eval() executes arbitrary expressions embedded in the
    # step text. Only run feature files from a trusted source.
    data = eval('f"""' + context.text + '"""')
    context.append_osm_data(context.nodes.add_coordinates(data.split('\n')))


@given("the (?P<step>[0-9.]+ )?grid(?: with origin (?P<origin_x>[0-9.-]+) (?P<origin_y>[0-9.-]+))?")
def osm_define_node_grid(context, step, origin_x, origin_y):
    """ Define the node grid from the step table and add all its nodes to
        the input data. Step size defaults to 0.1, the origin to 20.0 20.0.
    """
    step = float(step.strip()) if step else 0.1
    x = float(origin_x) if origin_x else 20.0
    y = float(origin_y) if origin_y else 20.0

    assert -180.0 < x < 180.0
    assert -90.0 < y < 90.0

    context.nodes.set_grid([context.table.headings] + [list(h) for h in context.table],
                           step, x, y)
    context.append_osm_data(context.nodes.get_as_opl())


################### Steps: Style file ############################## @given("the style file '(?P