# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Diagnostic plots for a :class:`SimulationBundle`.
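
Every plot function takes its input data — a :class:`SimulationBundle`, or
for the sweep and bootstrap figures a ``SweepResult`` / ``BandResult`` — plus
an optional :class:`matplotlib.axes.Axes` (or a sequence / 2-D array of axes),
and returns the axes it drew on. When no axes are passed, a new figure is
created with :func:`matplotlib.pyplot.subplots`. Beyond drawing (and, for the
sweep plots, loading each scenario's bundle with ``load_bundle``), functions
have no side effects: they never call :func:`matplotlib.pyplot.show` or
:meth:`matplotlib.figure.Figure.savefig` — that is the caller's job.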
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from iohmm_evac.report.constants import (
SCENARIO_COLORS,
SCENARIO_ORDER,
STATE_COLORS,
STATE_ORDER,
)
from iohmm_evac.report.loader import SimulationBundle, load_bundle
if TYPE_CHECKING:
from matplotlib.axes import Axes
from iohmm_evac.bootstrap.aggregate import BandResult
from iohmm_evac.sweep import SweepResult
# Public plotting API of this module. Every entry is a plot_* function
# defined below; private helpers (leading underscore) are deliberately not
# exported.
__all__ = [
"plot_bootstrap_bands",
"plot_cumulative_departures",
"plot_emission_summary",
"plot_household_trajectories",
"plot_state_occupancy",
"plot_sweep_departures",
"plot_sweep_network",
]
_BAND_METRIC_TITLES: dict[str, str] = {
"failed_evacuation_count": "Failed evacuations (count)",
"peak_enroute_share": "Peak EnRoute share",
"total_delay_hours": "Total delay (hours)",
"shelter_overflow_count": "Shelter overflow (count)",
}
_MAX_TRAJECTORY_HOUSEHOLDS = 6
def _state_share_panel(observations: pd.DataFrame) -> pd.DataFrame:
    """Return per-hour state shares as a ``t`` x state table.

    Rows are the distinct ``t`` values found in ``observations`` (ascending).
    Columns follow STATE_ORDER exactly: states that never occur get a column
    of zero counts, and labels outside STATE_ORDER are dropped. Each row is
    divided by its total; a row whose total is zero is divided by 1 instead,
    so it stays all-zero rather than turning into NaN.
    """
    per_hour = (
        observations.groupby(["t", "state"]).size().unstack(fill_value=0).sort_index()
    )
    ordered = per_hour.reindex(columns=list(STATE_ORDER), fill_value=0)
    row_totals = ordered.sum(axis=1).replace(0, 1)
    return ordered.div(row_totals, axis=0)
def _add_timeline_overlays(ax: Axes, bundle: SimulationBundle) -> None:
    """Mark the warning and landfall hours on ``ax`` with labelled verticals.

    The voluntary (dashed black) and mandatory (dashed firebrick) lines sit at
    the first timeline hour whose flag column is truthy; a flag that is never
    raised draws nothing. Landfall (solid black) is always drawn.
    """
    vol_hour, mand_hour = _scenario_warning_hours(bundle)
    if vol_hour is not None:
        ax.axvline(vol_hour, ls="--", color="black", alpha=0.5, label="voluntary")
    if mand_hour is not None:
        ax.axvline(mand_hour, ls="--", color="firebrick", alpha=0.6, label="mandatory")
    ax.axvline(bundle.t_landfall, ls="-", color="black", alpha=0.7, label="landfall")
def plot_state_occupancy(bundle: SimulationBundle, ax: Axes | None = None) -> Axes:
    """Draw state shares over time as a stacked area chart (Fig. 3).

    Bands are stacked in STATE_ORDER and coloured from STATE_COLORS. Dashed
    verticals mark the first voluntary and mandatory order hours found in
    the timeline, and a solid vertical marks landfall. The view is limited
    to ``[0, t_landfall]`` on x and ``[0, 1]`` on y.

    Args:
        bundle: Simulation output to plot.
        ax: Axes to draw on; a new 10x4 figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    target = ax if ax is not None else plt.subplots(figsize=(10, 4))[1]
    panel = _state_share_panel(bundle.observations)
    hours = panel.index.to_numpy()
    layers = panel.to_numpy().T
    target.stackplot(
        hours,
        layers,
        labels=list(STATE_ORDER),
        colors=[STATE_COLORS[name] for name in STATE_ORDER],
        alpha=0.9,
    )
    _add_timeline_overlays(target, bundle)
    target.set_xlim(0, bundle.t_landfall)
    target.set_ylim(0, 1)
    target.set_xlabel("Hours from start")
    target.set_ylabel("Population share")
    target.set_title("State occupancy over time")
    target.legend(loc="upper left", ncol=4, fontsize=8, framealpha=0.9)
    return target
def plot_cumulative_departures(bundle: SimulationBundle, ax: Axes | None = None) -> Axes:
    """Cumulative share of households that have ever departed, vs time (Fig. 4).

    A household's "departure hour" is the first ``t`` where its latent
    state is ER. Households that never enter ER never depart. We plot the
    cumulative share (out of ``bundle.n_households``) whose departure hour
    is ``<= t``, against the timeline hours ``t``.

    The ``departure`` emission column is *not* used here: it carries
    Bernoulli noise (~3% per hour even from non-evacuating households),
    which is appropriate for IO-HMM fitting in Build 2 but produces a
    visually misleading curve for sanity-checking the underlying
    behavioral dynamics.

    The curve itself comes from ``_cumulative_share``, the same helper
    ``plot_sweep_departures`` uses, so the single-bundle figure and the
    per-scenario overlay cannot drift apart.

    Args:
        bundle: Simulation output to plot.
        ax: Axes to draw on; a new 10x4 figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))
    timeline_t, cum = _cumulative_share(bundle)
    ax.plot(timeline_t, cum, color="#1f77b4", lw=2.0, label="cumulative")
    _add_timeline_overlays(ax, bundle)
    ax.set_xlim(0, bundle.t_landfall)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Hours from start")
    ax.set_ylabel("Cumulative departure share")
    ax.set_title("Cumulative departures")
    ax.legend(loc="upper left", fontsize=8, framealpha=0.9)
    return ax
def _state_codes(states: pd.Series) -> np.ndarray:
"""Map state-label strings to integer codes following STATE_ORDER."""
lookup = {s: i for i, s in enumerate(STATE_ORDER)}
out: np.ndarray = states.map(lookup).to_numpy(dtype=np.int64)
return out
def _validate_household_ids(bundle: SimulationBundle, household_ids: Sequence[int]) -> list[int]:
    """Check the requested household ids and return them as plain ``int``s.

    The input order is preserved.

    Raises:
        ValueError: If no ids are given, if more than
            ``_MAX_TRAJECTORY_HOUSEHOLDS`` ids are given, or if any id is
            absent from ``bundle.population["household_id"]``.
    """
    requested = list(household_ids)
    count = len(requested)
    if count == 0:
        msg = "household_ids must contain at least one id"
        raise ValueError(msg)
    if count > _MAX_TRAJECTORY_HOUSEHOLDS:
        msg = (
            "plot_household_trajectories supports at most "
            f"{_MAX_TRAJECTORY_HOUSEHOLDS} households (got {count})"
        )
        raise ValueError(msg)
    known = set(bundle.population["household_id"].astype(int).tolist())
    absent = [hid for hid in requested if int(hid) not in known]
    if absent:
        msg = f"Households not found in bundle: {absent}"
        raise ValueError(msg)
    return [int(hid) for hid in requested]
def _draw_trajectory(ax: Axes, bundle: SimulationBundle, hh_id: int) -> None:
    """Render a single household's trajectory panel onto ``ax``.

    All series share one y-axis whose ticks are the STATE_ORDER labels
    (tick ``i`` is ``STATE_ORDER[i]``):

    * the timeline ``forecast`` column, drawn in grey on its raw scale;
    * the household's latent state as a post-step line, coded by position
      in STATE_ORDER;
    * an ``x`` marker half a tick above the top state at the household's
      departure hour (first ``t`` with state ER), if it ever departs;
    * ``displacement`` rescaled so its maximum lands on the top state tick,
      drawn only when that maximum is positive;
    * dashed verticals at the first voluntary and mandatory order hours.

    Args:
        ax: Axes to draw on.
        bundle: Simulation output containing the household.
        hh_id: Household whose rows in ``bundle.observations`` are plotted.
    """
    observations = bundle.observations
    sub = observations[observations["household_id"] == hh_id].sort_values("t")
    t_arr = sub["t"].to_numpy()
    state_codes = _state_codes(sub["state"])
    top_code = len(STATE_ORDER) - 1

    timeline = bundle.timeline
    ax.plot(
        timeline["t"].to_numpy(),
        timeline["forecast"].to_numpy(),
        color="#777777",
        lw=1.0,
        label="forecast",
    )
    ax.step(t_arr, state_codes, where="post", color="black", lw=1.5, label="state")

    # The X mark is the household's actual departure hour: the first ``t``
    # at which the latent state is ER. The noisy ``departure`` emission
    # column is intentionally ignored here (see plot_cumulative_departures
    # for rationale).
    er_hours = sub.loc[sub["state"] == "ER", "t"].to_numpy()
    if er_hours.size:
        ax.scatter(
            [int(er_hours.min())],
            [top_code + 0.5],
            marker="x",
            color="#d73027",
            s=40,
            label="departure",
        )

    # Cast to float so the NaN check works for any numeric dtype. np.nanmax
    # raises on an empty array and warns on an all-NaN one, so a household
    # with no observation rows (or no displacement data) just skips the curve.
    disp = sub["displacement"].to_numpy(dtype=float)
    has_data = bool((~np.isnan(disp)).any())
    disp_max = float(np.nanmax(disp)) if has_data else 0.0
    if disp_max > 0:
        scaled = disp / max(disp_max, 1e-9) * top_code
        ax.plot(t_arr, scaled, color="#1a9850", lw=1.0, alpha=0.6, label="displacement")

    vol_t, mand_t = _scenario_warning_hours(bundle)
    if vol_t is not None:
        ax.axvline(vol_t, ls="--", color="black", alpha=0.4)
    if mand_t is not None:
        ax.axvline(mand_t, ls="--", color="firebrick", alpha=0.5)

    ax.set_yticks(range(len(STATE_ORDER)))
    ax.set_yticklabels(list(STATE_ORDER))
    ax.set_xlim(0, bundle.t_landfall)
    ax.set_title(f"Household {hh_id}")
    ax.legend(loc="upper left", fontsize=7, framealpha=0.85)
def plot_household_trajectories(
    bundle: SimulationBundle,
    household_ids: Sequence[int],
    ax: Sequence[Axes] | None = None,
) -> Sequence[Axes]:
    """Draw one forecast / state / departure / displacement panel per household (Fig. 2).

    Args:
        bundle: Simulation output containing the households.
        household_ids: Households to plot, one panel each, in this order.
        ax: Optional axes, exactly one per entry of ``household_ids``. When
            omitted, a new figure with vertically stacked, x-shared subplots
            is created.

    Returns:
        The list of axes drawn on; the last one carries the x-axis label.

    Raises:
        ValueError: If the ids fail validation or ``ax`` has the wrong length.
    """
    ids = _validate_household_ids(bundle, household_ids)
    if ax is not None:
        axes = list(ax)
        if len(axes) != len(ids):
            msg = (
                f"ax must have one entry per household_id (got {len(axes)} axes for {len(ids)} ids)"
            )
            raise ValueError(msg)
    else:
        n_panels = len(ids)
        # squeeze=False always yields a 2-D (n, 1) array, so a single panel
        # needs no special case.
        _, grid = plt.subplots(
            n_panels, 1, figsize=(10, 2.5 * n_panels), sharex=True, squeeze=False
        )
        axes = [cast("Axes", panel) for panel in grid[:, 0]]

    for panel_ax, hid in zip(axes, ids, strict=True):
        _draw_trajectory(panel_ax, bundle, hid)
    axes[-1].set_xlabel("Hours from start")
    return axes
def _emission_summary_table(observations: pd.DataFrame) -> pd.DataFrame:
    """Average the emission columns within each latent state.

    Returns one row per state that actually occurs in ``observations``,
    ordered by STATE_ORDER (labels outside STATE_ORDER are dropped), with
    columns ``departure``, ``displacement`` and ``comm_count``.
    """
    emission_cols = ["departure", "displacement", "comm_count"]
    means = observations.groupby("state")[emission_cols].mean()
    present = [label for label in STATE_ORDER if label in means.index]
    return means.loc[present]
def plot_emission_summary(bundle: SimulationBundle, ax: Axes | None = None) -> Axes:
    """Draw a grouped bar chart of per-state emission means (sanity check).

    Each state present in the observations (in STATE_ORDER order) gets a
    group of three 0.25-wide bars centred on its tick: mean ``departure``,
    ``displacement`` and ``comm_count``. All three share one unscaled y-axis.

    Args:
        bundle: Simulation output to summarise.
        ax: Axes to draw on; a new 8x4 figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 4))
    summary = _emission_summary_table(bundle.observations)
    bar_specs = (
        ("departure", "#4575b4"),
        ("displacement", "#fdae61"),
        ("comm_count", "#762a83"),
    )
    bar_width = 0.25
    group_centres = np.arange(summary.shape[0], dtype=float)
    # Shift bar k by (k - mid) widths so each group is centred on its tick.
    mid = (len(bar_specs) - 1) / 2
    for k, (column, colour) in enumerate(bar_specs):
        ax.bar(
            group_centres + (k - mid) * bar_width,
            summary[column].to_numpy(),
            width=bar_width,
            label=column,
            color=colour,
        )
    ax.set_xticks(group_centres)
    ax.set_xticklabels(list(summary.index))
    ax.set_ylabel("Mean per (household, t) row")
    ax.set_title("Emission summary by state")
    ax.legend(loc="upper left", fontsize=8, framealpha=0.9)
    return ax
def _scenario_warning_hours(bundle: SimulationBundle) -> tuple[int | None, int | None]:
"""Return the first (voluntary, mandatory) warning hours from a bundle's timeline."""
timeline = bundle.timeline
vol_mask = timeline["voluntary"].astype(bool).to_numpy()
mand_mask = timeline["mandatory"].astype(bool).to_numpy()
vol = int(timeline.loc[vol_mask, "t"].iloc[0]) if vol_mask.any() else None
mand = int(timeline.loc[mand_mask, "t"].iloc[0]) if mand_mask.any() else None
return vol, mand
def _ordered_sweep_scenarios(sweep: SweepResult) -> list[str]:
    """List the sweep's scenarios with the known ones first, in SCENARIO_ORDER.

    Scenarios that SCENARIO_ORDER does not mention follow, in the order the
    sweep config lists them.
    """
    configured = list(sweep.config.scenarios)
    result = [name for name in SCENARIO_ORDER if name in configured]
    result.extend(name for name in configured if name not in SCENARIO_ORDER)
    return result
def _cumulative_share(bundle: SimulationBundle) -> tuple[np.ndarray, np.ndarray]:
"""Return (timeline_t, cumulative-departure-share) for a single bundle."""
obs = bundle.observations
in_er = obs[obs["state"] == "ER"]
n_total = bundle.n_households
timeline_t = bundle.timeline["t"].to_numpy()
if in_er.empty:
return timeline_t, np.zeros_like(timeline_t, dtype=float)
first_er = in_er.groupby("household_id")["t"].min()
counts = first_er.value_counts().sort_index()
per_hour = pd.Series(0, index=timeline_t, dtype=np.int64)
per_hour.loc[counts.index] = counts.to_numpy(dtype=np.int64)
cum = per_hour.cumsum().to_numpy(dtype=float) / float(max(n_total, 1))
return timeline_t, cum
def plot_sweep_departures(sweep: SweepResult, ax: Axes | None = None) -> Axes:
    """Overlay cumulative-departure curves across scenarios (Fig. 4).

    Scenarios are loaded with ``load_bundle`` and drawn in
    ``_ordered_sweep_scenarios`` order, each in its SCENARIO_COLORS colour
    (matplotlib's default colour cycle when absent). Legend labels read
    ``"<scenario> (vol=<h>, mand=<h>)"``, with ``—`` for a warning that is
    never issued. One solid vertical marks the first scenario's landfall
    hour, which also sets the right x-limit. Per-scenario warning verticals
    are deliberately omitted because they differ across scenarios.

    Args:
        sweep: Sweep whose scenario bundles are plotted.
        ax: Axes to draw on; a new 10x4 figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    def hour_text(hour: int | None) -> str:
        return "—" if hour is None else f"{hour}"

    landfall: int | None = None
    for scenario in _ordered_sweep_scenarios(sweep):
        bundle = load_bundle(sweep.bundles[scenario])
        hours, share = _cumulative_share(bundle)
        vol, mand = _scenario_warning_hours(bundle)
        if landfall is None:
            landfall = bundle.t_landfall
        ax.plot(
            hours,
            share,
            color=SCENARIO_COLORS.get(scenario, None),
            lw=2.0,
            label=f"{scenario} (vol={hour_text(vol)}, mand={hour_text(mand)})",
        )

    if landfall is not None:
        ax.axvline(landfall, ls="-", color="black", alpha=0.7, label="landfall")
        ax.set_xlim(0, landfall)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Hours from start")
    ax.set_ylabel("Cumulative departure share")
    ax.set_title("Cumulative departures by scenario")
    ax.legend(loc="upper left", fontsize=8, framealpha=0.9)
    return ax
def _draw_metric_panel(
    ax: Axes,
    scenarios: list[str],
    values: list[float],
    title: str,
    labels: list[str],
) -> None:
    """Draw one horizontal-bar metric panel with one bar per scenario.

    Bars run top-to-bottom in ``scenarios`` order and are coloured from
    SCENARIO_COLORS (grey ``#777777`` for unknown scenarios). Each bar's
    text label is placed just past its end. When the largest value is
    positive, the x-axis is fixed to ``[0, 1.18 * max]`` to leave room for
    the labels; otherwise matplotlib's autoscaling is kept.

    Args:
        ax: Axes to draw on.
        scenarios: Scenario names, used as y tick labels.
        values: Bar lengths, aligned with ``scenarios``.
        title: Panel title.
        labels: Text annotations, aligned with ``scenarios``.
    """
    positions = np.arange(len(scenarios), dtype=float)
    ax.barh(
        positions,
        values,
        color=[SCENARIO_COLORS.get(name, "#777777") for name in scenarios],
        edgecolor="black",
        linewidth=0.5,
    )
    ax.set_yticks(positions)
    ax.set_yticklabels(scenarios)
    ax.invert_yaxis()
    ax.set_title(title)
    largest = 0.0 if not values else max(values)
    offset = max(largest * 0.02, 1e-3)
    for ypos, value, text in zip(positions, values, labels, strict=True):
        ax.text(value + offset, ypos, text, va="center", fontsize=8)
    if largest > 0:
        ax.set_xlim(0, largest * 1.18)
def plot_sweep_network(
    sweep: SweepResult,
    ax: Sequence[Sequence[Axes]] | np.ndarray | None = None,
) -> np.ndarray:
    """Draw a 2x2 grid of per-scenario network metrics (Fig. 5).

    Panel layout by (row, column): (0, 0) total delay, (0, 1) peak EnRoute
    share, (1, 0) shelter overflow, (1, 1) failed evacuations. Bars are
    labelled with their value; the peak-share bars also show the hour at
    which that peak is attained. Scenarios follow ``_ordered_sweep_scenarios``.

    Args:
        sweep: Sweep whose ``network_metrics`` are plotted.
        ax: Optional 2x2 sequence/ndarray of axes; a new 10x6 figure is
            created when omitted.

    Returns:
        The 2x2 object ndarray of axes that were drawn on.

    Raises:
        ValueError: If ``ax`` does not have shape (2, 2).
    """
    grid_src = plt.subplots(2, 2, figsize=(10, 6))[1] if ax is None else ax
    axes_arr = np.asarray(grid_src, dtype=object)
    if axes_arr.shape != (2, 2):
        msg = f"ax must be a 2x2 array of axes (got shape {axes_arr.shape})"
        raise ValueError(msg)

    names = _ordered_sweep_scenarios(sweep)
    metrics = sweep.network_metrics
    # (grid position, metrics attribute, panel title, label formatter). Each
    # formatter gets the scenario's metrics record and the bar value.
    panel_specs = (
        ((0, 0), "total_delay_hours", "Total delay (hours)", lambda rec, val: f"{val:.1f}"),
        (
            (0, 1),
            "peak_enroute_share",
            "Peak EnRoute share",
            lambda rec, val: f"{rec.peak_enroute_share:.3f} @ t={rec.peak_enroute_hour}",
        ),
        ((1, 0), "shelter_overflow_count", "Shelter overflow (count)", lambda rec, val: f"{int(val)}"),
        ((1, 1), "failed_evacuation_count", "Failed evacuations (count)", lambda rec, val: f"{int(val)}"),
    )
    for (row, col), attr, title, fmt in panel_specs:
        values = [float(getattr(metrics[name], attr)) for name in names]
        labels = [fmt(metrics[name], val) for name, val in zip(names, values, strict=True)]
        _draw_metric_panel(cast("Axes", axes_arr[row, col]), names, values, title, labels)
    return axes_arr
def plot_bootstrap_bands(
    band_result: BandResult,
    metric: str = "failed_evacuation_count",
    ax: Axes | None = None,
) -> Axes:
    """Render Fig. 6: median + 25–75% and 5–95% quantile bands across shifts.

    The solid line is the 50th percentile, the inner shaded band is the
    25–75% range, and the outer shaded band is the 5–95% range. A vertical
    dashed line at ``δ = 0`` marks baseline timing. Points are drawn in
    ascending shift order, whatever the order of ``band_result.shifts``.

    Args:
        band_result: Bootstrap summary. It must contain ``metric`` in
            ``bands`` and the 5/25/50/75/95 percentiles.
        metric: Key into ``band_result.bands``. Known keys get a readable
            label from ``_BAND_METRIC_TITLES``; others use the raw key.
        ax: Axes to draw on; a new 8x4.5 figure is created when omitted.

    Returns:
        The axes that were drawn on.

    Raises:
        KeyError: If ``metric`` is not in ``band_result.bands``.
        ValueError: If any required percentile is missing.
    """
    if metric not in band_result.bands:
        msg = f"Metric {metric!r} not in BandResult. Known: {tuple(band_result.bands.keys())}"
        raise KeyError(msg)
    required = {5, 25, 50, 75, 95}
    missing = required.difference(band_result.percentiles)
    if missing:
        msg = (
            f"BandResult is missing percentiles {sorted(missing)} required for plot_bootstrap_bands"
        )
        raise ValueError(msg)
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 4.5))

    shifts = np.asarray(band_result.shifts, dtype=np.float64)
    # fill_between and plot join points in array order, so an unsorted shift
    # grid would draw zig-zagging, self-intersecting bands. A stable argsort
    # leaves an already-ascending grid untouched.
    order = np.argsort(shifts, kind="stable")
    shifts = shifts[order]

    def band(pct: int) -> np.ndarray:
        # Assumes quantile() returns one value per entry of band_result.shifts,
        # in the same order (the plotting calls already rely on that
        # alignment) — NOTE(review): confirm against BandResult.
        return np.asarray(band_result.quantile(metric, pct))[order]

    p5 = band(5)
    p25 = band(25)
    p50 = band(50)
    p75 = band(75)
    p95 = band(95)
    ax.fill_between(shifts, p5, p95, color="#4575b4", alpha=0.20, label="5–95%")
    ax.fill_between(shifts, p25, p75, color="#4575b4", alpha=0.40, label="25–75%")
    ax.plot(shifts, p50, color="#1a3a73", lw=2.0, label="median")
    ax.axvline(0.0, ls="--", color="black", alpha=0.6, label="baseline timing")
    title = _BAND_METRIC_TITLES.get(metric, metric)
    ax.set_xlabel("Warning lead-time shift δ (hours)")
    ax.set_ylabel(title)
    ax.set_title(f"Bootstrap bands: {title}")
    ax.legend(loc="best", fontsize=8, framealpha=0.9)
    return ax