src/iohmm_evac/cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Command-line entry point for the simulator.

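Subcommands:

* ``iohmm-evac simulate ...`` — run a simulation and write Parquet outputs.
* ``iohmm-evac scenarios list`` — list the registered scenarios.
* ``iohmm-evac config dump`` — dump a scenario's configuration as TOML.
* ``iohmm-evac report ...`` — render diagnostic plots from a saved run.
* ``iohmm-evac fit ...``, ``diagnose ...``, ``sweep ...`` and ``bootstrap ...`` —
  registered by ``iohmm_evac.inference.cli``, ``iohmm_evac.diagnostics.cli``,
  ``iohmm_evac.sweep_cli`` and ``iohmm_evac.bootstrap_cli`` respectively.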
"""

from __future__ import annotations

import argparse
import sys
import tomllib
from collections.abc import Mapping, Sequence
from dataclasses import fields, is_dataclass, replace
from pathlib import Path
from typing import Any, get_type_hints

import numpy as np
import tomli_w

from iohmm_evac.bootstrap_cli import add_bootstrap_subparser, run_bootstrap_command
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.diagnostics.cli import add_diagnose_subparser, run_diagnose
from iohmm_evac.inference.cli import add_fit_subparser, run_fit
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig, to_nested_dict
from iohmm_evac.report.cli import add_report_subparser, run_report
from iohmm_evac.scenarios import build_scenario, list_scenarios
from iohmm_evac.sweep_cli import add_sweep_subparser, run_sweep_command

# Public API of this module: the parser factory and the CLI entry point.
__all__ = ["build_parser", "main"]


def _coerce(value: str, target_type: Any) -> Any:
    """Coerce a CLI string into the type of the target field."""
    origin = getattr(target_type, "__origin__", None)
    if target_type is bool:
        lowered = value.strip().lower()
        if lowered in {"true", "1", "yes", "y", "on"}:
            return True
        if lowered in {"false", "0", "no", "n", "off"}:
            return False
        msg = f"Cannot interpret {value!r} as bool"
        raise ValueError(msg)
    if target_type is int:
        return int(value)
    if target_type is float:
        return float(value)
    if target_type is str:
        return value
    if origin is tuple:
        # e.g. tuple[int, ...] or tuple[float, ...]
        inner = target_type.__args__[0]
        return tuple(_coerce(v.strip(), inner) for v in value.split(","))
    msg = f"Unsupported override target type: {target_type!r}"
    raise ValueError(msg)


def _replace_at_path(obj: Any, path: Sequence[str], new_value: Any) -> Any:
    """Return a copy of dataclass tree ``obj`` with ``path`` set to ``new_value``."""
    if not path:
        return new_value
    head, *rest = path
    if not is_dataclass(obj) or isinstance(obj, type):
        msg = f"Cannot descend into non-dataclass at path segment {head!r}"
        raise ValueError(msg)
    field_names = {f.name for f in fields(obj)}
    if head not in field_names:
        msg = f"Unknown field {head!r} on {type(obj).__name__}"
        raise ValueError(msg)
    current = getattr(obj, head)
    if rest:
        updated = _replace_at_path(current, rest, new_value)
    else:
        target_type = get_type_hints(type(obj)).get(head, type(current))
        updated = _coerce(new_value, target_type) if isinstance(new_value, str) else new_value
    return replace(obj, **{head: updated})


def _apply_set_overrides(config: SimulationConfig, overrides: list[str]) -> SimulationConfig:
    """Apply repeated ``--set key=value`` strings to ``config``."""
    for raw in overrides:
        if "=" not in raw:
            msg = f"--set expects KEY=VALUE, got {raw!r}"
            raise ValueError(msg)
        key, value = raw.split("=", 1)
        path = key.strip().split(".")
        config = _replace_at_path(config, path, value)
    assert isinstance(config, SimulationConfig)
    return config


def _apply_dict_overrides(obj: Any, data: Mapping[str, Any]) -> Any:
    """Recursively merge a mapping into the dataclass tree ``obj``."""
    if not is_dataclass(obj) or isinstance(obj, type):
        return data
    field_names = {f.name for f in fields(obj)}
    type_hints = get_type_hints(type(obj))
    updates: dict[str, Any] = {}
    for key, val in data.items():
        if key not in field_names:
            msg = f"Unknown field {key!r} on {type(obj).__name__}"
            raise ValueError(msg)
        current = getattr(obj, key)
        if isinstance(val, Mapping) and is_dataclass(current) and not isinstance(current, type):
            updates[key] = _apply_dict_overrides(current, val)
        elif isinstance(val, list):
            target_type = type_hints.get(key, type(current))
            origin = getattr(target_type, "__origin__", None)
            updates[key] = tuple(val) if origin is tuple else val
        else:
            updates[key] = val
    return replace(obj, **updates)


def _load_toml_config(path: Path) -> dict[str, Any]:
    with path.open("rb") as f:
        return tomllib.load(f)


def _resolve_config(args: argparse.Namespace) -> SimulationConfig:
    """Build the effective ``SimulationConfig`` for the ``simulate`` command.

    Layers are applied in increasing precedence:

    1. the named ``--scenario``;
    2. the optional ``--config`` TOML file;
    3. the dedicated size/seed/physics flags whose value is not ``None``;
    4. repeated ``--set KEY=VALUE`` overrides, in command-line order.

    Raises:
        FileNotFoundError: If ``--config`` does not point to a regular file.
        ValueError: If ``--n-hours`` and its alias ``--landfall-hour`` are both
            given with different values, or if a config/``--set`` override is
            invalid.
    """
    config = build_scenario(args.scenario)

    if args.config:
        cfg_path = Path(args.config)
        # is_file() rather than exists(): a directory would pass exists() and
        # then fail inside open() with a less helpful error.
        if not cfg_path.is_file():
            msg = f"Config file not found: {cfg_path}"
            raise FileNotFoundError(msg)
        config = _apply_dict_overrides(config, _load_toml_config(cfg_path))

    # --landfall-hour is an alias for --n-hours; refuse contradictory values
    # instead of silently letting one of them win.
    n_hours = args.n_hours
    if args.landfall_hour is not None:
        if n_hours is not None and n_hours != args.landfall_hour:
            msg = (
                f"--n-hours ({n_hours}) and its alias --landfall-hour "
                f"({args.landfall_hour}) disagree; pass only one of them"
            )
            raise ValueError(msg)
        n_hours = args.landfall_hour

    direct_overrides: dict[str, Any] = {
        name: value
        for name, value in (
            ("n_households", args.n_households),
            ("n_hours", n_hours),
            ("seed", args.seed),
        )
        if value is not None
    }
    if direct_overrides:
        config = replace(config, **direct_overrides)

    feedback_overrides: dict[str, Any] = {
        name: value
        for name, value in (
            ("n_cap", args.n_cap),
            ("shelter_capacity", args.shelter_capacity),
        )
        if value is not None
    }
    if feedback_overrides:
        config = replace(config, feedback=replace(config.feedback, **feedback_overrides))

    if args.set:
        config = _apply_set_overrides(config, args.set)

    return config


def _add_simulate_flags(p: argparse.ArgumentParser) -> None:
    """Register the ``simulate`` subcommand's options on parser ``p``.

    Every override flag defaults to ``None``, so config resolution can tell
    "not given" apart from an explicit value and keep the scenario or
    ``--config`` value. ``--seed`` follows the same convention. A non-``None``
    default would always overwrite any ``seed`` set in the ``--config`` file.
    """
    g_scenario = p.add_argument_group("scenario")
    g_scenario.add_argument(
        "--scenario",
        choices=list_scenarios(),
        default="baseline",
        help="Predefined scenario (default: baseline).",
    )
    g_scenario.add_argument(
        "--config",
        default=None,
        help="Optional TOML file with parameter overrides applied on top of the scenario.",
    )

    g_size = p.add_argument_group("size & seed")
    g_size.add_argument(
        "--n-households", type=int, default=None, help="Override N (default 10000)."
    )
    g_size.add_argument("--n-hours", type=int, default=None, help="Override T (default 120).")
    g_size.add_argument("--landfall-hour", type=int, default=None, help="Alias for --n-hours.")
    g_size.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Override the RNG seed (default: keep the scenario/--config seed).",
    )

    g_phys = p.add_argument_group("physics overrides")
    g_phys.add_argument("--n-cap", type=int, default=None, help="Override feedback N_cap.")
    g_phys.add_argument(
        "--shelter-capacity", type=int, default=None, help="Override shelter capacity."
    )

    g_out = p.add_argument_group("output")
    g_out.add_argument(
        "--output",
        type=Path,
        default=Path("./output/simulation.parquet"),
        help="Output Parquet path (default ./output/simulation.parquet).",
    )

    p.add_argument(
        "--set",
        action="append",
        default=[],
        metavar="KEY=VALUE",
        help="Fine-grained override, e.g. --set transitions.ua_to_aw.beta_mand=3.0 (repeatable).",
    )

    verbosity = p.add_mutually_exclusive_group()
    verbosity.add_argument("--quiet", action="store_true", help="Suppress non-error output.")
    verbosity.add_argument("--verbose", action="store_true", help="Print extra progress info.")


def build_parser() -> argparse.ArgumentParser:
    """Create the ``iohmm-evac`` parser with every subcommand registered.

    The first positional selects ``args.command``. The ``scenarios`` and
    ``config`` commands take a nested ``args.action``. The remaining
    subcommands are contributed by their own modules' ``add_*_subparser``
    helpers.
    """
    parser = argparse.ArgumentParser(
        prog="iohmm-evac",
        description="Synthetic hurricane-evacuation DGP runner.",
    )
    commands = parser.add_subparsers(dest="command", required=True)

    _add_simulate_flags(commands.add_parser("simulate", help="Run a simulation and write outputs."))

    scenarios_cmd = commands.add_parser("scenarios", help="Inspect predefined scenarios.")
    scenarios_actions = scenarios_cmd.add_subparsers(dest="action", required=True)
    scenarios_actions.add_parser("list", help="List registered scenarios.")

    config_cmd = commands.add_parser("config", help="Inspect resolved configurations.")
    config_actions = config_cmd.add_subparsers(dest="action", required=True)
    dump_cmd = config_actions.add_parser("dump", help="Dump the resolved config as TOML.")
    dump_cmd.add_argument(
        "--scenario", choices=list_scenarios(), default="baseline", help="Scenario to dump."
    )

    # Registration order determines the order shown in --help.
    for register in (
        add_fit_subparser,
        add_diagnose_subparser,
        add_report_subparser,
        add_sweep_subparser,
        add_bootstrap_subparser,
    ):
        register(commands)

    return parser


def _cmd_simulate(args: argparse.Namespace) -> int:
    """Run the ``simulate`` subcommand and return exit code 0.

    Resolves the layered config and runs :func:`simulate` with a NumPy
    generator seeded from ``config.seed``. The result is passed to
    ``write_results`` together with ``args.output``. Progress goes to stderr:
    ``--quiet`` suppresses it, and ``--verbose`` adds a final-state summary.
    """
    config = _resolve_config(args)
    generator = np.random.default_rng(config.seed)
    chatty = not args.quiet

    if chatty:
        banner = (
            f"Running simulation: scenario={args.scenario}, "
            f"N={config.n_households}, T={config.n_hours}, seed={config.seed}"
        )
        print(banner, file=sys.stderr)

    result = simulate(config, generator)
    written = write_results(result, args.output)

    if chatty:
        for label, location in written.items():
            print(f"{label}: {location}", file=sys.stderr)
    if args.verbose:
        household_count = result.population.n
        # NOTE(review): state code 4 in the last time column is reported as
        # "SH"; confirm against the DGP's state encoding.
        final_states = result.states[:, -1]
        share = float(np.mean(final_states == 4))
        print(f"final SH share: {share:.3f} ({household_count} households)", file=sys.stderr)

    return 0


def _cmd_scenarios(args: argparse.Namespace) -> int:
    if args.action == "list":
        for name in list_scenarios():
            print(name)
        return 0
    msg = f"Unknown scenarios action: {args.action!r}"
    raise ValueError(msg)


def _cmd_config(args: argparse.Namespace) -> int:
    if args.action == "dump":
        config = build_scenario(args.scenario)
        nested = to_nested_dict(config)
        assert isinstance(nested, dict)
        sys.stdout.write(tomli_w.dumps(nested))
        return 0
    msg = f"Unknown config action: {args.action!r}"
    raise ValueError(msg)


def main(argv: Sequence[str] | None = None) -> int:
    """Parse ``argv`` (``sys.argv[1:]`` when ``None``) and run the chosen subcommand.

    Returns:
        The exit code produced by the subcommand's handler.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    handlers = {
        "simulate": _cmd_simulate,
        "scenarios": _cmd_scenarios,
        "config": _cmd_config,
        "fit": run_fit,
        "diagnose": run_diagnose,
        "report": run_report,
        "sweep": run_sweep_command,
        "bootstrap": run_bootstrap_command,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.error(f"Unknown command: {args.command!r}")  # pragma: no cover
    return handler(args)


if __name__ == "__main__":  # pragma: no cover
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())