tests/test_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import subprocess
import sys
import tomllib
from pathlib import Path

import pyarrow.parquet as pq
import pytest

from iohmm_evac.cli import build_parser, main


def _run_cli(args: list[str], cwd: Path | None = None) -> subprocess.CompletedProcess[str]:
    proc = subprocess.run(
        [sys.executable, "-m", "iohmm_evac.cli", *args],
        capture_output=True,
        text=True,
        cwd=cwd,
    )
    return proc


def test_parser_smoke() -> None:
    """``simulate`` parses and exposes ``--n-households`` as an integer."""
    namespace = build_parser().parse_args(
        ["simulate", "--n-households", "10", "--n-hours", "5"]
    )
    assert namespace.command == "simulate"
    assert namespace.n_households == 10


def test_main_scenarios_list(capsys: pytest.CaptureFixture[str]) -> None:
    """``scenarios list`` prints exactly the known scenario names, one per line."""
    expected_names = {"baseline", "early-warning", "targeted-messaging", "contraflow"}
    exit_code = main(["scenarios", "list"])
    printed_lines = capsys.readouterr().out.splitlines()
    assert exit_code == 0
    assert set(printed_lines) == expected_names


def test_main_config_dump_returns_valid_toml(capsys: pytest.CaptureFixture[str]) -> None:
    """``config dump`` emits parseable TOML carrying the scenario's timeline hours."""
    exit_code = main(["config", "dump", "--scenario", "early-warning"])
    dumped_text = capsys.readouterr().out
    assert exit_code == 0
    timeline = tomllib.loads(dumped_text)["timeline"]
    assert timeline["voluntary_hour"] == 48
    assert timeline["mandatory_hour"] == 72


def test_main_simulate_writes_outputs(tmp_path: Path) -> None:
    """``simulate`` writes the observation table plus its sidecar artifacts.

    Next to ``sim.parquet`` the CLI is expected to produce
    ``sim.population.parquet``, ``sim.timeline.parquet`` and ``sim.config.toml``.
    """
    n_households, n_hours = 50, 12
    out = tmp_path / "sim.parquet"
    argv = [
        "simulate",
        "--n-households",
        str(n_households),
        "--n-hours",
        str(n_hours),
        "--seed",
        "7",
        "--output",
        str(out),
        "--quiet",
    ]
    assert main(argv) == 0
    assert out.exists()
    for sidecar_name in ("sim.population.parquet", "sim.timeline.parquet", "sim.config.toml"):
        assert (tmp_path / sidecar_name).exists()

    obs = pq.read_table(out)  # type: ignore[no-untyped-call]
    required_columns = {
        "household_id",
        "t",
        "state",
        "departure",
        "displacement",
        "comm_count",
    }
    assert required_columns <= set(obs.schema.names)
    # One row per household per time step; n_hours + 1 steps per household
    # (presumably t = 0..n_hours inclusive).
    assert obs.num_rows == n_households * (n_hours + 1)


def test_set_override_lands(tmp_path: Path) -> None:
    """Every ``--set KEY=VALUE`` override shows up in the written ``sim.config.toml``."""
    out = tmp_path / "sim.parquet"
    overrides = ["transitions.ua_to_aw.beta_mand=99.0", "feedback.n_cap=42"]
    # Expand each override into a ``--set <override>`` flag pair, preserving order.
    set_flags = [token for override in overrides for token in ("--set", override)]
    argv = [
        "simulate",
        "--n-households",
        "20",
        "--n-hours",
        "6",
        "--output",
        str(out),
        "--quiet",
        *set_flags,
    ]
    assert main(argv) == 0
    written = tomllib.loads((tmp_path / "sim.config.toml").read_text())
    assert written["transitions"]["ua_to_aw"]["beta_mand"] == 99.0
    assert written["feedback"]["n_cap"] == 42


def test_subprocess_invocation_smoke(tmp_path: Path) -> None:
    """``python -m iohmm_evac.cli simulate`` succeeds end-to-end in a child process."""
    out = tmp_path / "sim.parquet"
    cli_args = [
        "simulate",
        "--n-households",
        "30",
        "--n-hours",
        "8",
        "--seed",
        "1",
        "--output",
        str(out),
        "--quiet",
    ]
    completed = _run_cli(cli_args)
    # Include the child's stderr in the failure message so crashes are diagnosable.
    assert completed.returncode == 0, completed.stderr
    assert out.exists()


def test_set_rejects_bad_field(tmp_path: Path) -> None:
    """An unknown dotted key passed to ``--set`` is rejected with ``ValueError``.

    The output path lives under pytest's per-test ``tmp_path`` rather than a
    hard-coded ``/tmp`` location. This keeps the test portable (no ``/tmp`` on
    Windows) and isolated from concurrent runs. It also means a regression in
    the validation cannot litter the shared temp directory.
    """
    out = tmp_path / "should-not-be-written.parquet"
    with pytest.raises(ValueError, match="Unknown field"):
        main(
            [
                "simulate",
                "--output",
                str(out),
                "--set",
                "bogus.path=1.0",
            ]
        )
    # The rejection should happen before any simulation output is produced.
    assert not out.exists()


def test_set_rejects_missing_equals(tmp_path: Path) -> None:
    """A ``--set`` argument lacking ``=`` is rejected with a ``KEY=VALUE`` hint.

    ``--output`` is pinned under ``tmp_path`` for the same reason as the other
    rejection test. If the validation ever regressed, the resulting simulation
    would write into the per-test temp directory instead of the CLI's default
    output location.
    """
    out = tmp_path / "sim.parquet"
    with pytest.raises(ValueError, match="KEY=VALUE"):
        main(
            [
                "simulate",
                "--output",
                str(out),
                "--set",
                "transitions.ua_to_aw.beta_mand",
            ]
        )


def test_landfall_hour_alias(tmp_path: Path) -> None:
    """``--landfall-hour`` (without ``--n-hours``) sets ``n_hours`` in the written config."""
    out = tmp_path / "sim.parquet"
    argv = [
        "simulate",
        "--n-households",
        "20",
        "--landfall-hour",
        "10",
        "--output",
        str(out),
        "--quiet",
    ]
    assert main(argv) == 0
    written = tomllib.loads((tmp_path / "sim.config.toml").read_text())
    assert written["n_hours"] == 10


def test_config_file_loads(tmp_path: Path) -> None:
    """Values from a ``--config`` TOML file are applied and echoed in ``sim.config.toml``."""
    override_path = tmp_path / "override.toml"
    override_path.write_text(
        "[transitions.ua_to_aw]\n"
        "beta_mand = 5.5\n"
        "[feedback]\n"
        "n_cap = 999\n"
    )
    out = tmp_path / "sim.parquet"
    argv = [
        "simulate",
        "--n-households",
        "20",
        "--n-hours",
        "6",
        "--output",
        str(out),
        "--config",
        str(override_path),
        "--quiet",
    ]
    assert main(argv) == 0
    written = tomllib.loads((tmp_path / "sim.config.toml").read_text())
    assert written["transitions"]["ua_to_aw"]["beta_mand"] == 5.5
    assert written["feedback"]["n_cap"] == 999


def test_config_file_missing_raises(tmp_path: Path) -> None:
    """Pointing ``--config`` at a nonexistent file raises ``FileNotFoundError``."""
    missing_config = tmp_path / "missing.toml"
    argv = [
        "simulate",
        "--config",
        str(missing_config),
        "--output",
        str(tmp_path / "out.parquet"),
    ]
    with pytest.raises(FileNotFoundError):
        main(argv)


def test_verbose_flag(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """With ``--verbose`` the run reports a ``final SH share`` summary on stderr."""
    out = tmp_path / "sim.parquet"
    argv = [
        "simulate",
        "--n-households",
        "20",
        "--n-hours",
        "6",
        "--output",
        str(out),
        "--verbose",
    ]
    assert main(argv) == 0
    stderr_text = capsys.readouterr().err
    assert "final SH share" in stderr_text