src/iohmm_evac/params.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Frozen dataclasses bundling DGP parameters.

Every numeric default in the chapter spec lives here; nothing else should
hardcode parameter values. Override at the CLI via ``--set`` or via a TOML
config file.
"""

from __future__ import annotations

from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any

# Public API of this module. ``to_nested_dict`` is a public (non-underscore)
# helper for TOML-compatible serialisation of these configs, so it is
# exported alongside the parameter classes.
__all__ = [
    "EmissionParams",
    "FeedbackParams",
    "PopulationParams",
    "SimulationConfig",
    "TimelineParams",
    "TransitionParams",
    "TransitionRow",
    "to_nested_dict",
]


@dataclass(frozen=True, slots=True)
class PopulationParams:
    """Parameters of the household-covariate distributions."""

    distance_mu: float = 15.0
    distance_sigma: float = 10.0
    distance_lo: float = 0.5
    distance_hi: float = 50.0
    vehicle_p: float = 0.85
    risk_mu: float = 0.0
    risk_sigma: float = 1.0
    zone_a_threshold: float = 5.0
    zone_b_threshold: float = 20.0
    dest_lo: float = 40.0
    dest_hi: float = 100.0
    targeted_zone_multiplier: float = 1.0
    """Risk multiplier for zone A (used by the targeted-messaging scenario)."""


@dataclass(frozen=True, slots=True)
class TimelineParams:
    """Forecast and warning-order timeline parameters."""

    forecast_breakpoints: tuple[int, ...] = (0, 48, 72, 96)
    forecast_levels: tuple[float, ...] = (1.0, 2.0, 3.0, 4.0)
    forecast_noise_sigma: float = 0.15
    voluntary_hour: int = 60
    mandatory_hour: int = 84


@dataclass(frozen=True, slots=True)
class TransitionRow:
    """Coefficients for a single multinomial-logit transition row.

    Unused features are simply zero. Forbidden destinations are encoded
    elsewhere (in the allowed-transition mask), so all rows here describe
    *allowed* destinations only.
    """

    alpha: float
    beta_vol: float = 0.0
    beta_mand: float = 0.0
    beta_rho: float = 0.0
    beta_pi: float = 0.0
    beta_r: float = 0.0
    beta_v: float = 0.0
    beta_tau: float = 0.0
    beta_negc: float = 0.0
    beta_negr: float = 0.0
    beta_negv: float = 0.0
    beta_tir: float = 0.0


@dataclass(frozen=True, slots=True)
class TransitionParams:
    """Per-row logit coefficients for every allowed state-to-state move.

    Each attribute is named ``<origin>_to_<destination>`` and holds the
    ``TransitionRow`` for that move. Staying put (k -> k) is the reference
    category with an implicit logit of 0. Forbidden moves get logit -inf and
    are masked out by the sampler, so they have no attribute here.
    """

    ua_to_aw: TransitionRow = field(
        default_factory=lambda: TransitionRow(
            alpha=-6.5,
            beta_vol=1.5,
            beta_mand=2.5,
            beta_rho=0.5,
            beta_r=0.4,
            beta_tau=1.0,
        )
    )
    aw_to_ua: TransitionRow = field(
        default_factory=lambda: TransitionRow(
            alpha=-3.0,
        )
    )
    aw_to_pr: TransitionRow = field(
        default_factory=lambda: TransitionRow(
            alpha=-6.0,
            beta_mand=2.0,
            beta_rho=0.6,
            beta_pi=1.5,
            beta_r=0.5,
            beta_v=0.6,
            beta_tau=1.5,
        )
    )
    pr_to_er: TransitionRow = field(
        default_factory=lambda: TransitionRow(
            alpha=-5.5,
            beta_mand=1.5,
            beta_r=0.4,
            beta_v=0.8,
            beta_tau=2.0,
            beta_negc=1.5,
        )
    )
    pr_to_sh: TransitionRow = field(
        default_factory=lambda: TransitionRow(
            alpha=-6.0,
            beta_negr=0.6,
            beta_negv=0.8,
        )
    )
    er_to_sh: TransitionRow = field(
        default_factory=lambda: TransitionRow(
            alpha=-3.0,
            beta_negc=1.0,
            beta_tir=1.2,
        )
    )


@dataclass(frozen=True, slots=True)
class EmissionParams:
    """State-conditional emission parameters."""

    p_departure_er: float = 0.95
    p_departure_other: float = 0.03
    displacement_idle_sigma: float = 0.5
    displacement_route_sigma: float = 1.0
    displacement_destination_sigma: float = 0.5
    v_free: float = 40.0
    congestion_penalty: float = 0.6
    lambda_ua: float = 0.2
    lambda_aw: float = 1.5
    lambda_pr: float = 4.0
    lambda_er: float = 2.0
    lambda_sh: float = 0.5


@dataclass(frozen=True, slots=True)
class FeedbackParams:
    """Endogenous-feedback parameters."""

    n_cap: int = 1500
    shelter_capacity: int = 3000  # held for later builds; not used by the DGP itself.


@dataclass(frozen=True, slots=True)
class SimulationConfig:
    """Root configuration object for a single simulation run.

    Three scalars set the run size and seed. Each parameter group is a
    nested, frozen sub-config, built from that group's own defaults when not
    supplied explicitly.
    """

    # Run size and seed.
    n_households: int = 10_000
    n_hours: int = 120
    seed: int = 0

    # Parameter groups; default_factory builds a fresh default per instance.
    population: PopulationParams = field(default_factory=PopulationParams)
    timeline: TimelineParams = field(default_factory=TimelineParams)
    transitions: TransitionParams = field(default_factory=TransitionParams)
    emissions: EmissionParams = field(default_factory=EmissionParams)
    feedback: FeedbackParams = field(default_factory=FeedbackParams)


def to_nested_dict(obj: Any) -> Any:
    """Recursively convert a dataclass tree into a nested dict.

    Tuples are preserved as lists for TOML compatibility.
    """
    if is_dataclass(obj) and not isinstance(obj, type):
        return {f.name: to_nested_dict(getattr(obj, f.name)) for f in fields(obj)}
    if isinstance(obj, tuple | list):
        return [to_nested_dict(v) for v in obj]
    return obj