# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Adapter: turn a :class:`SimulationBundle` into IO-HMM input/output arrays."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from iohmm_evac.inference.fit_params import FEATURE_NAMES, F, K
from iohmm_evac.report.loader import SimulationBundle
from iohmm_evac.types import FloatArray, IntArray, State
__all__ = ["FitData", "bundle_to_fit_data"]
@dataclass(frozen=True, slots=True)
class FitData:
"""Per-household inputs ``u`` and emission observations ``y``."""
inputs: FloatArray
"""Inputs ``u_{i,t}``, shape ``(N, T+1, F)``."""
departure: FloatArray
"""Bernoulli ``D_{i,t}`` (cast to float for weighted MLE), shape ``(N, T+1)``."""
displacement: FloatArray
"""Gaussian ``X_{i,t}``, shape ``(N, T+1)``."""
comm: FloatArray
"""Poisson ``C_{i,t}`` (kept as float for weighted MLE), shape ``(N, T+1)``."""
true_states: IntArray | None
"""Optional true state path for diagnostics, shape ``(N, T+1)``."""
@property
def n(self) -> int:
"""Number of households."""
return int(self.departure.shape[0])
@property
def t_total(self) -> int:
"""Number of forward steps; the time axis has ``T + 1`` slots."""
return int(self.departure.shape[1] - 1)
@property
def k(self) -> int:
"""Number of latent states (always ``K``)."""
return K
@property
def f(self) -> int:
"""Input dimension (always ``F``)."""
return F
# Maps each ``State`` member's name to its integer code, so the string labels
# in the observation frame can be turned into an integer state path.
_STATE_TO_CODE: dict[str, int] = {member.name: int(member) for member in State}
def bundle_to_fit_data(bundle: SimulationBundle) -> FitData:
    """Project the bundle's panel into the IO-HMM's array layout.

    The observation frame is sorted by ``(household_id, t)``, the population
    frame by ``household_id`` and the timeline frame by ``t``. Observation
    columns are then reshaped row-major into ``(N, T+1)`` arrays, where ``N``
    is the number of population rows and ``T+1`` the number of timeline rows.

    Args:
        bundle: Simulation output exposing ``observations``, ``population``
            and ``timeline`` data frames.

    Returns:
        A :class:`FitData` holding the input tensor of shape ``(N, T+1, F)``,
        the three emission arrays and the integer-coded true state path, each
        of shape ``(N, T+1)``.

    Raises:
        ValueError: If the observation row count differs from ``N * (T + 1)``.
        KeyError: If a ``state`` label is not the name of a ``State`` member.
    """
    obs = bundle.observations.sort_values(["household_id", "t"])
    pop = bundle.population.sort_values("household_id").reset_index(drop=True)
    timeline = bundle.timeline.sort_values("t").reset_index(drop=True)

    n_households = int(pop.shape[0])
    t_plus_1 = int(timeline.shape[0])
    expected_rows = n_households * t_plus_1
    # NOTE: only the row count is validated. The reshapes below assume every
    # household has exactly one observation per timeline step and that the
    # household ids match the population frame -- confirm upstream.
    if obs.shape[0] != expected_rows:
        msg = (
            f"Observation row count ({obs.shape[0]}) "
            f"does not match N*(T+1) = {expected_rows}"
        )
        raise ValueError(msg)

    panel_shape = (n_households, t_plus_1)

    # Emissions: go through bool first so the departure flag becomes exactly 0.0 / 1.0.
    departure = obs["departure"].to_numpy(dtype=bool).reshape(panel_shape).astype(np.float64)
    displacement = obs["displacement"].to_numpy(dtype=np.float64).reshape(panel_shape)
    comm = obs["comm_count"].to_numpy(dtype=np.float64).reshape(panel_shape)

    # State labels are looked up by name; an unknown label raises KeyError.
    true_states = np.array(
        [_STATE_TO_CODE[label] for label in obs["state"].to_numpy()],
        dtype=np.int64,
    ).reshape(panel_shape)

    # Per-time-step drivers, each of length T+1.
    forecast = timeline["forecast"].to_numpy(dtype=np.float64)
    voluntary = timeline["voluntary"].to_numpy(dtype=bool).astype(np.float64)
    mandatory = timeline["mandatory"].to_numpy(dtype=bool).astype(np.float64)

    # Per-household attributes, each of length N.
    distance = pop["distance_km"].to_numpy(dtype=np.float64)
    risk = pop["risk"].to_numpy(dtype=np.float64)
    vehicle = pop["vehicle"].to_numpy(dtype=bool).astype(np.float64)

    # Feature columns are addressed by name; any feature in FEATURE_NAMES that
    # is not assigned below is left at zero.
    inputs = np.zeros((n_households, t_plus_1, F), dtype=np.float64)
    idx = {name: i for i, name in enumerate(FEATURE_NAMES)}
    inputs[:, :, idx["vol"]] = voluntary[None, :]
    inputs[:, :, idx["mand"]] = mandatory[None, :]
    # The forecast is damped exponentially with household distance (scale 10.0).
    inputs[:, :, idx["rho"]] = forecast[None, :] * np.exp(-distance[:, None] / 10.0)
    inputs[:, :, idx["r"]] = risk[:, None]
    inputs[:, :, idx["v"]] = vehicle[:, None]

    # Normalised time in [0, 1]. Clamping the divisor to 1 keeps tau at zero
    # when there are no forward steps, instead of dividing by zero.
    t_total = t_plus_1 - 1
    tau = np.arange(t_plus_1, dtype=np.float64) / float(max(t_total, 1))
    inputs[:, :, idx["tau"]] = tau[None, :]

    return FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=true_states,
    )