src/iohmm_evac/network/apply.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""High-level entry point: ``SimulationBundle`` → :class:`NetworkMetrics`."""

from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd

from iohmm_evac.network.metrics import NetworkMetrics, compute_metrics_from_arrays
from iohmm_evac.report.constants import STATE_ORDER
from iohmm_evac.report.loader import SimulationBundle
from iohmm_evac.types import FloatArray, IntArray, State

# Only the high-level bundle entry point is exported from this module.
__all__ = ["compute_network_metrics"]


# Integer encoding of the population table's ``evac_path`` labels; the encoded
# per-household column is passed to ``compute_metrics_from_arrays`` as
# ``evac_path``.
_EVAC_PATH_CODE: dict[str, int] = {"none": 0, "away": 1, "home": 2}


def _encode_labels(labels: pd.Series[Any], lookup: dict[str, int], what: str) -> IntArray:
    """Map string *labels* to ``int64`` codes via *lookup*, rejecting unknown labels.

    An unmapped label would otherwise become ``NaN``, which cannot be
    represented in the ``int64`` result.

    Raises:
        ValueError: if any label (including a missing value) is not a key of
            *lookup*; the message lists the offending labels.
    """
    codes = labels.map(lookup)
    unknown_mask = codes.isna().to_numpy(dtype=bool)
    if unknown_mask.any():
        unknown = sorted({str(v) for v in labels.to_numpy()[unknown_mask]})
        raise ValueError(f"unknown {what} label(s): {unknown}")
    return codes.to_numpy(dtype=np.int64)


def _wide_arrays(bundle: SimulationBundle) -> tuple[IntArray, FloatArray, IntArray]:
    """Pivot the long observations DataFrame back into wide ``(N, T+1)`` arrays.

    Rows of all three returned arrays are ordered by ascending
    ``household_id``; columns of the two panel arrays by ascending ``t``.

    Returns:
        ``(states, displacements, evac_path)``:

        * ``states``: the integer ``State`` code of each observed state label.
        * ``displacements``: the matching ``displacement`` values as float64.
        * ``evac_path``: each household's population ``evac_path`` label,
          encoded through ``_EVAC_PATH_CODE`` and aligned to the panel rows
          by ``household_id``.

    Raises:
        ValueError: if an observation has a state label outside
            ``STATE_ORDER``, if the observations do not cover every
            ``(household_id, t)`` cell, if the population has duplicate
            household ids, if an observed household has no population row,
            or if an ``evac_path`` label is not in ``_EVAC_PATH_CODE``.
    """
    obs = bundle.observations
    state_lookup = {label: int(getattr(State, label)) for label in STATE_ORDER}
    df = pd.DataFrame(
        {
            "household_id": obs["household_id"].to_numpy(dtype=np.int64),
            "t": obs["t"].to_numpy(dtype=np.int64),
            "state_code": _encode_labels(obs["state"], state_lookup, "state"),
            "displacement": obs["displacement"].to_numpy(dtype=np.float64),
        }
    )
    # ``pivot`` raises on duplicate (household_id, t) pairs. Sort both axes
    # explicitly so rows follow household_id and columns follow t.
    states_pivot = df.pivot(index="household_id", columns="t", values="state_code")
    states_pivot = states_pivot.sort_index().sort_index(axis=1)
    disp_pivot = df.pivot(index="household_id", columns="t", values="displacement")
    disp_pivot = disp_pivot.sort_index().sort_index(axis=1)

    # A household missing any time step leaves NaN holes in the panel, which
    # cannot be represented in the int64 state array.
    if states_pivot.isna().to_numpy().any():
        raise ValueError("observations do not contain every (household_id, t) combination")

    # Align population rows to the observed households by id rather than by
    # position. Extra, missing or reordered population rows would otherwise
    # shift evac_path out of step with the state/displacement rows.
    pop = bundle.population
    evac_labels = pd.Series(
        pop["evac_path"].to_numpy(),
        index=pop["household_id"].to_numpy(dtype=np.int64),
    )
    if not evac_labels.index.is_unique:
        raise ValueError("population contains duplicate household_id values")
    missing = states_pivot.index.difference(evac_labels.index)
    if len(missing) > 0:
        raise ValueError(
            f"{len(missing)} observed household(s) have no population row, "
            f"e.g. household_id={missing[:5].tolist()}"
        )
    evac_path_codes = _encode_labels(
        evac_labels.reindex(states_pivot.index), _EVAC_PATH_CODE, "evac_path"
    )

    return (
        states_pivot.to_numpy(dtype=np.int64),
        disp_pivot.to_numpy(dtype=np.float64),
        evac_path_codes,
    )


def _config_value(config: dict[str, Any], path: tuple[str, ...], default: float) -> float:
    cur: Any = config
    for key in path:
        if not isinstance(cur, dict) or key not in cur:
            return default
        cur = cur[key]
    return float(cur)


def compute_network_metrics(bundle: SimulationBundle) -> NetworkMetrics:
    """Build :class:`NetworkMetrics` for a loaded simulation bundle.

    The observation and population tables are first reshaped into wide
    arrays. The metric parameters are then read from the bundle's sidecar
    config, so per-scenario overrides (for example ``contraflow``'s larger
    ``n_cap``) are honoured. The parameters and their fallbacks are:

    * ``feedback.n_cap``: defaults to 1500, truncated to ``int``.
    * ``feedback.shelter_capacity``: defaults to 3000, truncated to ``int``.
    * ``emissions.v_free``: defaults to 40.0.
    * ``emissions.congestion_penalty``: defaults to 0.6.
    """
    states, displacements, evac_path = _wide_arrays(bundle)
    config = bundle.config

    def setting(section: str, key: str, default: float) -> float:
        # Two-level ``[section] key`` lookup with a fallback default.
        return _config_value(config, (section, key), default)

    return compute_metrics_from_arrays(
        states=states,
        displacements=displacements,
        evac_path=evac_path,
        n_cap=int(setting("feedback", "n_cap", 1500)),
        shelter_capacity=int(setting("feedback", "shelter_capacity", 3000)),
        v_free=setting("emissions", "v_free", 40.0),
        congestion_penalty=setting("emissions", "congestion_penalty", 0.6),
    )