tests/test_report_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.cli import main
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig


@pytest.fixture
def simulation(tmp_path: Path) -> Path:
    """Simulate a small seeded scenario and return the path of its Parquet file.

    The scenario has 50 households over 24 hours. The RNG is seeded from the
    config's ``seed``, and the result is written to ``<tmp_path>/sim.parquet``.
    """
    cfg = SimulationConfig(n_households=50, n_hours=24, seed=0)
    parquet_path = tmp_path / "sim.parquet"
    write_results(simulate(cfg, np.random.default_rng(cfg.seed)), parquet_path)
    return parquet_path


def test_report_all_writes_four_pngs(tmp_path: Path, simulation: Path) -> None:
    """``report all`` renders every figure, non-empty, into ``--output-dir``."""
    figures_dir = tmp_path / "figures"
    argv = ["report", "all", "--input", str(simulation), "--output-dir", str(figures_dir)]
    exit_code = main(argv)
    assert exit_code == 0
    figure_names = ("occupancy.png", "departures.png", "emissions.png", "trajectories.png")
    for png in (figures_dir / figure for figure in figure_names):
        assert png.exists(), f"Missing: {png}"
        assert png.stat().st_size > 0


def test_report_occupancy_default_output(simulation: Path) -> None:
    """Without ``--output`` the figure is written beside the input file.

    The expected name is ``<input stem>.occupancy.png`` in the input's directory.
    """
    default_png = simulation.with_name(f"{simulation.stem}.occupancy.png")
    exit_code = main(["report", "occupancy", "--input", str(simulation)])
    assert exit_code == 0
    assert default_png.exists()
    assert default_png.stat().st_size > 0


def test_report_departures_explicit_output(tmp_path: Path, simulation: Path) -> None:
    """``report departures`` writes a non-empty PNG to the explicit ``--output`` path."""
    target_png = tmp_path / "depart.png"
    argv = ["report", "departures", "--input", str(simulation), "--output", str(target_png)]
    exit_code = main(argv)
    assert exit_code == 0
    assert target_png.exists()
    assert target_png.stat().st_size > 0


def test_report_emissions(tmp_path: Path, simulation: Path) -> None:
    """``report emissions`` writes a non-empty PNG to the explicit ``--output`` path."""
    out = tmp_path / "em.png"
    rc = main(["report", "emissions", "--input", str(simulation), "--output", str(out)])
    assert rc == 0
    assert out.exists()
    # An existing but zero-byte file means the figure was never actually rendered;
    # the other report tests guard against this too.
    assert out.stat().st_size > 0


def test_report_trajectories_with_ids(tmp_path: Path, simulation: Path) -> None:
    """``report trajectories`` accepts a comma-separated ``--household-ids`` list.

    The ids (0, 5, 17) all fall inside the fixture's 50 simulated households.
    """
    out = tmp_path / "traj.png"
    rc = main(
        [
            "report",
            "trajectories",
            "--input",
            str(simulation),
            "--household-ids",
            "0,5,17",
            "--output",
            str(out),
        ]
    )
    assert rc == 0
    assert out.exists()
    # Guard against an empty file being left behind by a failed render.
    assert out.stat().st_size > 0


def test_report_trajectories_default_ids(tmp_path: Path, simulation: Path) -> None:
    """``report trajectories`` works without ``--household-ids`` (CLI picks the ids)."""
    out = tmp_path / "traj.png"
    rc = main(
        [
            "report",
            "trajectories",
            "--input",
            str(simulation),
            "--output",
            str(out),
        ]
    )
    assert rc == 0
    assert out.exists()
    # Guard against an empty file being left behind by a failed render.
    assert out.stat().st_size > 0


def test_report_all_subprocess(tmp_path: Path, simulation: Path) -> None:
    """Run ``python -m iohmm_evac.cli report all`` end to end in a child process.

    The child gets a reduced environment from ``_inherit_env`` plus
    ``MPLBACKEND=Agg``, which selects matplotlib's non-interactive backend.
    """
    out_dir = tmp_path / "figures-sub"
    proc = subprocess.run(
        [
            sys.executable,
            "-m",
            "iohmm_evac.cli",
            "report",
            "all",
            "--input",
            str(simulation),
            "--output-dir",
            str(out_dir),
        ],
        capture_output=True,
        text=True,
        env={"MPLBACKEND": "Agg", **_inherit_env()},
        # Inspect the return code ourselves so stderr appears in the failure message.
        check=False,
        # Fail (TimeoutExpired) rather than hang the whole suite if the child stalls.
        timeout=300,
    )
    assert proc.returncode == 0, proc.stderr
    for name in ("occupancy.png", "departures.png", "emissions.png", "trajectories.png"):
        target = out_dir / name
        assert target.exists(), f"Missing: {target}"
        assert target.stat().st_size > 0


def test_report_household_ids_invalid_token() -> None:
    """A non-integer ``--household-ids`` token makes the CLI exit with a failure status.

    The ``--input`` path is a placeholder. The test expects the CLI to reject
    ``--household-ids`` with ``SystemExit`` rather than fail on the missing file.
    """
    with pytest.raises(SystemExit) as excinfo:
        main(
            [
                "report",
                "trajectories",
                "--input",
                "/tmp/does-not-matter.parquet",
                "--household-ids",
                "abc",
            ]
        )
    # SystemExit(0) / SystemExit(None) signal success (e.g. ``--help``); a rejected
    # argument must produce a failing exit status.
    assert excinfo.value.code not in (0, None)


def test_report_household_ids_empty() -> None:
    """An empty ``--household-ids`` value makes the CLI exit with a failure status.

    The ``--input`` path is a placeholder. The test expects the CLI to reject
    ``--household-ids`` with ``SystemExit`` rather than fail on the missing file.
    """
    with pytest.raises(SystemExit) as excinfo:
        main(
            [
                "report",
                "trajectories",
                "--input",
                "/tmp/does-not-matter.parquet",
                "--household-ids",
                "",
            ]
        )
    # SystemExit(0) / SystemExit(None) signal success (e.g. ``--help``); a rejected
    # argument must produce a failing exit status.
    assert excinfo.value.code not in (0, None)


def _inherit_env() -> dict[str, str]:
    import os

    keep = ("PATH", "HOME", "USER", "VIRTUAL_ENV", "PYTHONPATH")
    return {k: os.environ[k] for k in keep if k in os.environ}