src/iohmm_evac/inference/data.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Adapter: turn a :class:`SimulationBundle` into IO-HMM input/output arrays."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from iohmm_evac.inference.fit_params import FEATURE_NAMES, F, K
from iohmm_evac.report.loader import SimulationBundle
from iohmm_evac.types import FloatArray, IntArray, State

__all__ = ["FitData", "bundle_to_fit_data"]


@dataclass(frozen=True, slots=True)
class FitData:
    """Per-household inputs ``u`` and emission observations ``y``."""

    inputs: FloatArray
    """Inputs ``u_{i,t}``, shape ``(N, T+1, F)``."""
    departure: FloatArray
    """Bernoulli ``D_{i,t}`` (cast to float for weighted MLE), shape ``(N, T+1)``."""
    displacement: FloatArray
    """Gaussian ``X_{i,t}``, shape ``(N, T+1)``."""
    comm: FloatArray
    """Poisson ``C_{i,t}`` (kept as float for weighted MLE), shape ``(N, T+1)``."""
    true_states: IntArray | None
    """Optional true state path for diagnostics, shape ``(N, T+1)``."""

    @property
    def n(self) -> int:
        """Number of households."""
        return int(self.departure.shape[0])

    @property
    def t_total(self) -> int:
        """Number of forward steps; the time axis has ``T + 1`` slots."""
        return int(self.departure.shape[1] - 1)

    @property
    def k(self) -> int:
        """Number of latent states (always ``K``)."""
        return K

    @property
    def f(self) -> int:
        """Input dimension (always ``F``)."""
        return F


# Lookup from a State member's name (the label found in the observation
# panel's "state" column) to that member's integer code.
_STATE_TO_CODE: dict[str, int] = {member.name: int(member) for member in State}


def _check_panel_alignment(
    obs_ids: np.ndarray,
    obs_t: np.ndarray,
    pop_ids: np.ndarray,
    timeline_t: np.ndarray,
) -> None:
    """Raise ``ValueError`` unless the observation rows form the population x timeline grid.

    ``obs_ids``/``obs_t`` are the observation panel's ``household_id``/``t`` columns,
    already sorted by household then time; ``pop_ids``/``timeline_t`` are the sorted
    population ids and timeline steps. A matching row count alone does not guarantee
    that reshaping to ``(N, T+1)`` puts each row under the right household and step,
    so the identifiers themselves are compared.
    """
    grid = (pop_ids.shape[0], timeline_t.shape[0])
    if not np.array_equal(obs_ids.reshape(grid), np.broadcast_to(pop_ids[:, None], grid)):
        msg = "Observation household_id values do not match the population table"
        raise ValueError(msg)
    if not np.array_equal(obs_t.reshape(grid), np.broadcast_to(timeline_t[None, :], grid)):
        msg = "Observation t values do not match the timeline for every household"
        raise ValueError(msg)


def bundle_to_fit_data(
    bundle: SimulationBundle,
    *,
    distance_scale_km: float = 10.0,
) -> FitData:
    """Project the bundle's panel into the IO-HMM's array layout.

    Observations are sorted by ``(household_id, t)``, the population by
    ``household_id`` and the timeline by ``t``; the observation columns are then
    reshaped into ``(N, T+1)`` grids, where ``N`` is the number of population rows
    and ``T+1`` the number of timeline rows.

    Args:
        bundle: Simulation output exposing ``observations``, ``population`` and
            ``timeline`` tables.
        distance_scale_km: Decay length of the ``rho`` input,
            ``forecast * exp(-distance_km / distance_scale_km)``. The default ``10.0``
            reproduces the previously hard-coded value.

    Returns:
        A :class:`FitData` whose ``inputs`` have shape ``(N, T+1, F)``.
        ``true_states`` is ``None`` when the observation table has no ``state`` column.

    Raises:
        ValueError: If ``distance_scale_km`` is not positive, the observation row
            count is not ``N*(T+1)``, or the observation ``household_id``/``t`` values
            do not line up with the population and timeline tables.
        KeyError: If a ``state`` label is not the name of a :class:`State` member.
    """
    if not distance_scale_km > 0:  # also rejects NaN
        msg = f"distance_scale_km must be positive, got {distance_scale_km}"
        raise ValueError(msg)

    obs = bundle.observations.sort_values(["household_id", "t"])
    pop = bundle.population.sort_values("household_id").reset_index(drop=True)
    timeline = bundle.timeline.sort_values("t").reset_index(drop=True)

    n_households = int(pop.shape[0])
    t_plus_1 = int(timeline.shape[0])
    grid = (n_households, t_plus_1)
    if obs.shape[0] != n_households * t_plus_1:
        msg = (
            "Observation row count "
            f"({obs.shape[0]}) does not match N*(T+1) = {n_households * t_plus_1}"
        )
        raise ValueError(msg)
    # Guard against same-sized but mismatched panels, which would otherwise be
    # reshaped silently into rows belonging to the wrong household or step.
    _check_panel_alignment(
        obs["household_id"].to_numpy(),
        obs["t"].to_numpy(),
        pop["household_id"].to_numpy(),
        timeline["t"].to_numpy(),
    )

    # Emissions: one row per household, one column per time slot.
    departure = obs["departure"].to_numpy(dtype=bool).reshape(grid).astype(np.float64)
    displacement = obs["displacement"].to_numpy(dtype=np.float64).reshape(grid)
    comm = obs["comm_count"].to_numpy(dtype=np.float64).reshape(grid)

    # True states are diagnostics only (FitData.true_states is optional), so a
    # panel without a "state" column yields None instead of failing.
    true_states: IntArray | None = None
    if "state" in obs.columns:
        codes = [_STATE_TO_CODE[label] for label in obs["state"].to_numpy()]
        true_states = np.asarray(codes, dtype=np.int64).reshape(grid)

    forecast = timeline["forecast"].to_numpy(dtype=np.float64)
    voluntary = timeline["voluntary"].to_numpy(dtype=bool).astype(np.float64)
    mandatory = timeline["mandatory"].to_numpy(dtype=bool).astype(np.float64)
    distance = pop["distance_km"].to_numpy(dtype=np.float64)
    risk = pop["risk"].to_numpy(dtype=np.float64)
    vehicle = pop["vehicle"].to_numpy(dtype=bool).astype(np.float64)

    # Normalised time running from 0 to 1 over the horizon; a timeline with at
    # most one slot has no span to normalise by, so tau is all zeros there.
    t_total = t_plus_1 - 1
    if t_total > 0:
        tau = np.arange(t_plus_1, dtype=np.float64) / float(t_total)
    else:
        tau = np.zeros(t_plus_1, dtype=np.float64)

    # Any FEATURE_NAMES entry not assigned below is left at zero.
    inputs = np.zeros((n_households, t_plus_1, F), dtype=np.float64)
    idx = {name: i for i, name in enumerate(FEATURE_NAMES)}
    # Timeline-level features, broadcast across households ...
    inputs[:, :, idx["vol"]] = voluntary[None, :]
    inputs[:, :, idx["mand"]] = mandatory[None, :]
    inputs[:, :, idx["tau"]] = tau[None, :]
    # ... household-level features, broadcast across time ...
    inputs[:, :, idx["r"]] = risk[:, None]
    inputs[:, :, idx["v"]] = vehicle[:, None]
    # ... and forecast exposure attenuated exponentially with distance.
    inputs[:, :, idx["rho"]] = forecast[None, :] * np.exp(-distance[:, None] / distance_scale_km)

    return FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=true_states,
    )