# tests/test_fit_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Subprocess-based CLI smoke tests for fit, diagnose, and report subcommands."""

from __future__ import annotations

import subprocess
import sys
from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig


def _inherit_env() -> dict[str, str]:
    """Return the allow-listed slice of the current process environment.

    Only ``PATH``, ``HOME``, ``USER``, ``VIRTUAL_ENV`` and ``PYTHONPATH`` are
    considered, and each one is included only when it is actually set in
    ``os.environ``. Every other variable is left out of the result.
    """
    import os

    inherited: dict[str, str] = {}
    for name in ("PATH", "HOME", "USER", "VIRTUAL_ENV", "PYTHONPATH"):
        value = os.environ.get(name)
        # Unset variables are skipped rather than forwarded as empty strings.
        if value is not None:
            inherited[name] = value
    return inherited


@pytest.fixture
def tiny_simulation(tmp_path: Path) -> Path:
    """Simulate a small dataset and return the path of the file it was written to.

    The simulation is configured with ``n_households=80``, ``n_hours=24`` and
    ``seed=0``; the random generator is seeded from that same config value.
    Results are written with ``write_results`` to ``tiny.parquet`` inside the
    per-test ``tmp_path`` directory.
    """
    cfg = SimulationConfig(n_households=80, n_hours=24, seed=0)
    parquet_path = tmp_path / "tiny.parquet"
    generator = np.random.default_rng(cfg.seed)
    write_results(simulate(cfg, generator), parquet_path)
    return parquet_path


def _run(args: list[str]) -> subprocess.CompletedProcess[str]:
    """Run ``python -m iohmm_evac.cli`` with *args* and return the completed process.

    The child is started with the current interpreter (``sys.executable``),
    its stdout and stderr are captured as text, and a non-zero exit status is
    not raised here; callers inspect ``returncode`` themselves.

    The child environment is ``MPLBACKEND=Agg`` plus whatever
    ``_inherit_env()`` returns; nothing else from the parent environment is
    passed through.
    """
    command = [sys.executable, "-m", "iohmm_evac.cli"]
    command.extend(args)

    child_env = {"MPLBACKEND": "Agg"}
    child_env.update(_inherit_env())

    return subprocess.run(command, capture_output=True, text=True, env=child_env)


def test_fit_writes_bundle_files(tmp_path: Path, tiny_simulation: Path) -> None:
    """``fit`` exits with status 0 and leaves every expected file in its output directory."""
    bundle_dir = tmp_path / "fit"

    fit_args = ["fit", "--input", str(tiny_simulation), "--output", str(bundle_dir)]
    fit_args += ["--init", "truth", "--max-iter", "3", "--quiet"]
    result = _run(fit_args)
    # Surface the CLI's stderr in the assertion message when the run fails.
    assert result.returncode == 0, result.stderr

    expected_files = (
        "theta.toml",
        "log_likelihood_trace.parquet",
        "posterior_states.parquet",
        "metadata.toml",
    )
    for name in expected_files:
        artifact = bundle_dir / name
        assert artifact.exists(), f"missing {name}"


def test_diagnose_recovery_writes_toml(tmp_path: Path, tiny_simulation: Path) -> None:
    """``diagnose recovery`` succeeds on a fresh fit and writes ``recovery.toml`` into it."""
    bundle_dir = tmp_path / "fit"

    # Step 1: produce a fit bundle for the diagnose step to read.
    fit_args = ["fit", "--input", str(tiny_simulation), "--output", str(bundle_dir)]
    fit_args += ["--init", "truth", "--max-iter", "3", "--quiet"]
    fit_result = _run(fit_args)
    assert fit_result.returncode == 0, fit_result.stderr

    # Step 2: run the recovery diagnostic against that bundle and the simulated truth.
    diagnose_args = ["diagnose", "recovery", "--fit", str(bundle_dir)]
    diagnose_args += ["--truth", str(tiny_simulation)]
    diagnose_result = _run(diagnose_args)
    assert diagnose_result.returncode == 0, diagnose_result.stderr

    recovery_file = bundle_dir / "recovery.toml"
    assert recovery_file.exists()


def test_report_recovery_subcommands_render(tmp_path: Path, tiny_simulation: Path) -> None:
    """Fit once, then check that each ``report`` subcommand runs against that bundle.

    Covered subcommands:

    * ``recovery-confusion`` and ``parameter-recovery`` (given ``--fit`` and
      ``--truth``) must exit 0 and create their ``--output`` file.
    * ``ll-trace`` (given ``--fit`` only) must exit 0 and create its
      ``--output`` file.
    * ``fit-summary`` must exit 0 and print text containing ``"best restart"``.
    """
    fit_dir = tmp_path / "fit"
    proc = _run(
        [
            "fit",
            "--input",
            str(tiny_simulation),
            "--output",
            str(fit_dir),
            "--init",
            "truth",
            "--max-iter",
            "3",
            "--quiet",
        ]
    )
    # Fail fast on a broken fit: otherwise the failure only shows up later as
    # a confusing error from the first report subcommand.
    assert proc.returncode == 0, proc.stderr

    fit_ref = ["--fit", str(fit_dir)]
    truth_ref = ["--truth", str(tiny_simulation)]

    # (subcommand, output filename, extra arguments placed between --fit and --output)
    figure_reports = (
        ("recovery-confusion", "confusion.png", truth_ref),
        ("parameter-recovery", "params.png", truth_ref),
        ("ll-trace", "ll.png", []),
    )
    for subcommand, filename, extra in figure_reports:
        png = tmp_path / filename
        proc = _run(["report", subcommand, *fit_ref, *extra, "--output", str(png)])
        assert proc.returncode == 0, proc.stderr
        assert png.exists(), f"report {subcommand} did not write {filename}"

    # fit-summary writes to stdout instead of a file.
    proc = _run(["report", "fit-summary", *fit_ref])
    assert proc.returncode == 0, proc.stderr
    assert "best restart" in proc.stdout