tests/test_report_summary.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

from dataclasses import replace
from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.cli import main
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.report.summary import bundle_summary, format_summary
from iohmm_evac.scenarios import build_scenario


@pytest.fixture
def baseline_bundle_path(tmp_path: Path) -> Path:
    """Simulate a small, seeded baseline scenario and write it to disk.

    The scenario is shrunk to 200 households over 120 hours with seed 0 so the
    fixture stays fast and reproducible.

    Returns:
        Path to the ``sum.parquet`` results file written under ``tmp_path``.
    """
    small_config = replace(
        build_scenario("baseline"),
        n_households=200,
        n_hours=120,
        seed=0,
    )
    simulated = simulate(small_config, np.random.default_rng(small_config.seed))
    bundle_file = tmp_path / "sum.parquet"
    write_results(simulated, bundle_file)
    return bundle_file


def test_bundle_summary_keys(baseline_bundle_path: Path) -> None:
    """``bundle_summary`` returns exactly the expected set of metric names."""
    summary = bundle_summary(load_bundle(baseline_bundle_path))
    expected_names = {
        "share_sheltered_at_t48",
        "share_sheltered_at_landfall",
        "share_failed_evacuation",
        "share_evacuated_away",
        "share_sheltered_in_place",
        "peak_enroute_share",
        "peak_enroute_hour",
        "median_departure_hour",
    }
    observed_names = set(summary.keys())
    assert observed_names == expected_names


def test_bundle_summary_matches_simulationresult(tmp_path: Path) -> None:
    """Bundle-derived metrics should match the in-memory ``summary()`` values.

    Each metric from ``SimulationResult.summary()`` must also be produced by
    ``bundle_summary`` on the written-then-reloaded bundle. The values must
    agree within an absolute tolerance of 1e-9, so the match is not
    bit-for-bit. A metric that is NaN in memory must also be NaN in the
    bundle-derived result.
    """
    config = replace(build_scenario("baseline"), n_households=300, n_hours=120, seed=3)
    rng = np.random.default_rng(config.seed)
    result = simulate(config, rng)
    inmem = result.summary()
    out = tmp_path / "match.parquet"
    write_results(result, out)
    derived = bundle_summary(load_bundle(out))

    # Guard against a vacuous pass: the comparison loop below checks nothing
    # if summary() returns an empty mapping.
    assert inmem, "SimulationResult.summary() returned no metrics"
    # Report every missing metric at once instead of a bare KeyError on the first.
    missing = sorted(set(inmem) - set(derived))
    assert not missing, f"bundle_summary is missing metrics: {missing}"

    for k, v in inmem.items():
        # nan_ok=True: an undefined metric (NaN) on both sides counts as a
        # match instead of failing on NaN != NaN.
        assert derived[k] == pytest.approx(v, abs=1e-9, nan_ok=True), k


def test_format_summary_lists_all_metrics(baseline_bundle_path: Path) -> None:
    """The rendered summary mentions every metric name and the word "metric"."""
    bundle = load_bundle(baseline_bundle_path)
    metrics = bundle_summary(bundle)
    rendered = format_summary(metrics)
    absent = [name for name in metrics if name not in rendered]
    assert not absent
    assert "metric" in rendered


def test_report_summary_cli_smoke(
    baseline_bundle_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """``report summary --input <bundle>`` exits with 0 and prints metric names."""
    argv = ["report", "summary", "--input", str(baseline_bundle_path)]
    exit_code = main(argv)
    stdout = capsys.readouterr().out
    assert exit_code == 0
    for expected_metric in ("share_sheltered_at_t48", "median_departure_hour"):
        assert expected_metric in stdout