# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""High-level entry point: ``SimulationBundle`` → :class:`NetworkMetrics`."""
from __future__ import annotations
from typing import Any
import numpy as np
import pandas as pd
from iohmm_evac.network.metrics import NetworkMetrics, compute_metrics_from_arrays
from iohmm_evac.report.constants import STATE_ORDER
from iohmm_evac.report.loader import SimulationBundle
from iohmm_evac.types import FloatArray, IntArray, State
__all__ = ["compute_network_metrics"]

# Integer codes used when encoding the population table's ``evac_path`` labels.
_EVAC_PATH_CODE: dict[str, int] = {
    "none": 0,
    "away": 1,
    "home": 2,
}
def _pivot_observations(obs: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Pivot long observations into ``household_id x t`` state-code and displacement frames.

    Both returned frames share the same index (ascending ``household_id``) and
    the same columns (ascending ``t``).

    Raises:
        ValueError: if a ``state`` label is not in ``STATE_ORDER``, a
            ``(household_id, t)`` pair is duplicated (raised by ``pivot``), or
            the observations leave some ``(household_id, t)`` cell empty.
    """
    state_lookup = {label: int(getattr(State, label)) for label in STATE_ORDER}
    state_codes = obs["state"].map(state_lookup)
    unknown_mask = state_codes.isna().to_numpy()
    if unknown_mask.any():
        unknown = sorted({str(label) for label in obs["state"].to_numpy()[unknown_mask]})
        raise ValueError(f"observations contain state labels not in STATE_ORDER: {unknown}")

    long = pd.DataFrame(
        {
            "household_id": obs["household_id"].to_numpy(dtype=np.int64),
            "t": obs["t"].to_numpy(dtype=np.int64),
            "state_code": state_codes.to_numpy(dtype=np.int64),
            "displacement": obs["displacement"].to_numpy(dtype=np.float64),
        }
    )
    states = (
        long.pivot(index="household_id", columns="t", values="state_code")
        .sort_index(axis=0)
        .sort_index(axis=1)
    )
    # A hole in the panel shows up as NaN here; without this check it would be
    # cast to a meaningless int64 by ``to_numpy`` downstream.
    holes = states.isna().to_numpy()
    if holes.any():
        raise ValueError(
            "observations are not a complete household x time panel: "
            f"{int(holes.sum())} (household_id, t) cell(s) missing"
        )
    # Reindex onto the state frame's axes so both frames are cell-aligned.
    displacements = long.pivot(
        index="household_id", columns="t", values="displacement"
    ).reindex(index=states.index, columns=states.columns)
    return states, displacements


def _evac_path_codes(population: pd.DataFrame, household_ids: pd.Index) -> IntArray:
    """Encode ``population["evac_path"]`` via ``_EVAC_PATH_CODE`` in ``household_ids`` order.

    Rows are matched on ``household_id`` (not on row position), so the result
    lines up with any array indexed by ``household_ids``. Population households
    not listed in ``household_ids`` are ignored.

    Raises:
        ValueError: if the population repeats a ``household_id``, lacks a
            household listed in ``household_ids``, or has an ``evac_path``
            label that is not a key of ``_EVAC_PATH_CODE``.
    """
    paths = pd.Series(
        population["evac_path"].to_numpy(),
        index=pd.Index(population["household_id"].to_numpy(dtype=np.int64)),
    )
    if paths.index.has_duplicates:
        raise ValueError("population contains duplicate household_id values")
    missing = household_ids.difference(paths.index)
    if len(missing) > 0:
        raise ValueError(
            f"population is missing {len(missing)} household(s) present in the "
            f"observations, e.g. {missing[:5].tolist()}"
        )
    aligned = paths.reindex(household_ids)
    codes = aligned.map(_EVAC_PATH_CODE)
    unknown_mask = codes.isna().to_numpy()
    if unknown_mask.any():
        unknown = sorted({str(label) for label in aligned.to_numpy()[unknown_mask]})
        raise ValueError(
            f"population contains unknown evac_path labels {unknown}; "
            f"expected one of {sorted(_EVAC_PATH_CODE)}"
        )
    return codes.to_numpy(dtype=np.int64)


def _wide_arrays(bundle: SimulationBundle) -> tuple[IntArray, FloatArray, IntArray]:
    """Pivot the long observations DataFrame back into wide ``(N, T+1)`` arrays.

    Rows are households in ascending ``household_id`` order and columns are
    time steps in ascending ``t`` order. The population's ``evac_path`` labels
    are matched to that same household order by ``household_id``, so all three
    arrays are row-aligned. Population households that never appear in the
    observations are not included.

    Returns:
        ``(states, displacements, evac_path)``: ``int64`` :class:`State` codes
        and ``float64`` displacements (one row per household, one column per
        time step), plus one ``int64`` ``_EVAC_PATH_CODE`` code per household.

    Raises:
        ValueError: if the observations contain unknown state labels or
            duplicate/missing ``(household_id, t)`` cells, or if the population
            duplicates or lacks an observed household or has an unknown
            ``evac_path`` label.
    """
    states, displacements = _pivot_observations(bundle.observations)
    evac_path = _evac_path_codes(bundle.population, states.index)
    return (
        states.to_numpy(dtype=np.int64),
        displacements.to_numpy(dtype=np.float64),
        evac_path,
    )
def _config_value(config: dict[str, Any], path: tuple[str, ...], default: float) -> float:
    """Look up a numeric value in a nested config dictionary.

    Walks ``config`` one key of ``path`` at a time, e.g. ``("feedback", "n_cap")``
    reads ``config["feedback"]["n_cap"]``. If any level along the way is not a
    ``dict`` or lacks the next key, ``default`` is returned unchanged.

    Args:
        config: Nested configuration dictionary.
        path: Keys to follow from the top level.
        default: Value returned when ``path`` is absent.

    Returns:
        The value at ``path`` converted with :func:`float`, or ``default``.

    Raises:
        TypeError: if the value is a boolean (which :func:`float` would silently
            turn into ``0.0``/``1.0``) or of a type :func:`float` rejects.
        ValueError: if the value is a string that does not parse as a number.
    """
    node: Any = config
    for key in path:
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    dotted = ".".join(path)
    # bool is an int subclass, so float(True) == 1.0 would quietly turn a
    # mistyped ``n_cap = true`` into a capacity of one.
    if isinstance(node, (bool, np.bool_)):
        raise TypeError(f"config value {dotted!r} must be numeric, got bool {node!r}")
    try:
        return float(node)
    except (TypeError, ValueError) as err:
        # Keep the original exception type, but name the offending key.
        raise type(err)(f"config value {dotted!r} is not numeric: {node!r}") from err
def compute_network_metrics(bundle: SimulationBundle) -> NetworkMetrics:
    """Build :class:`NetworkMetrics` for a loaded simulation bundle.

    The state, displacement and evacuation-path arrays are rebuilt from the
    bundle's observation and population tables. Four scalar parameters are
    read from the bundle's sidecar config TOML so that scenario-level
    overrides (notably ``contraflow``'s larger ``n_cap``) flow through:

    * ``feedback.n_cap`` -- default 1500, truncated to ``int``
    * ``feedback.shelter_capacity`` -- default 3000, truncated to ``int``
    * ``emissions.v_free`` -- default 40.0
    * ``emissions.congestion_penalty`` -- default 0.6

    A key absent from the config falls back to its default.
    """
    states, displacements, evac_path = _wide_arrays(bundle)
    cfg = bundle.config

    # Dict literals evaluate left to right, so lookups happen in the same order
    # as the keyword arguments they feed.
    network_params = {
        "n_cap": int(_config_value(cfg, ("feedback", "n_cap"), 1500)),
        "shelter_capacity": int(_config_value(cfg, ("feedback", "shelter_capacity"), 3000)),
        "v_free": _config_value(cfg, ("emissions", "v_free"), 40.0),
        "congestion_penalty": _config_value(cfg, ("emissions", "congestion_penalty"), 0.6),
    }
    return compute_metrics_from_arrays(
        states=states,
        displacements=displacements,
        evac_path=evac_path,
        **network_params,
    )