tests/test_bootstrap_plot.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for plot_bootstrap_bands."""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.collections import PolyCollection

from iohmm_evac.bootstrap.aggregate import compute_bands
from iohmm_evac.bootstrap.shift_sweep import ShiftSweepResult, SweepRow
from iohmm_evac.report.plots import plot_bootstrap_bands


def _tiny_sweep(
    *,
    n_replicates: int = 5,
    shifts: tuple[int, ...] = (-8, 0, 8),
    seed: int = 0,
) -> ShiftSweepResult:
    """Build a small, deterministic synthetic shift sweep for plotting tests.

    Every replicate gets one row per shift. The metrics are simple linear
    functions of the shift. ``failed_evacuation_count`` also gets integer
    noise in ``[-3, 3]`` from a seeded generator, so the rows of one shift
    differ across replicates and the percentile bands have non-zero width.

    Args:
        n_replicates: Number of bootstrap replicates to synthesise.
        shifts: Shift values, in the order rows are generated.
        seed: Seed for ``np.random.default_rng``; fixes the noise.

    Returns:
        A ``ShiftSweepResult`` whose ``shifts`` and ``n_replicates`` come
        from the same values used to generate its rows, so the metadata
        cannot drift out of sync with the data.
    """
    rng = np.random.default_rng(seed)
    rows: list[SweepRow] = []
    # Replicate-major order; the RNG draw sequence depends on this order.
    for rep in range(n_replicates):
        for shift in shifts:
            # integers(-3, 4) has an exclusive upper bound -> noise in [-3, 3]
            failed = int(120 - 4 * shift + rng.integers(-3, 4))
            rows.append(
                SweepRow(
                    replicate_id=rep,
                    shift=shift,
                    failed_evacuation_count=failed,
                    peak_enroute_share=0.1 + 0.001 * shift,
                    total_delay_hours=10.0 - 0.1 * shift,
                    shelter_overflow_count=int(50 - shift),
                )
            )
    return ShiftSweepResult(
        rows=tuple(rows), shifts=tuple(shifts), n_replicates=n_replicates
    )


def test_plot_bootstrap_bands_renders() -> None:
    """Plotting onto a caller-supplied Axes draws band fills on that Axes.

    The function must return the exact Axes it was given, and at least two
    filled ``PolyCollection`` regions must be present afterwards.
    """
    sweep = _tiny_sweep()
    bands = compute_bands(sweep, percentiles=(5, 25, 50, 75, 95))
    fig, ax = plt.subplots()
    try:
        out = plot_bootstrap_bands(bands, metric="failed_evacuation_count", ax=ax)
        assert out is ax
        # Band fills are expected to be PolyCollections (presumably drawn via
        # fill_between); other artists such as the median line are ignored.
        fills = [c for c in ax.collections if isinstance(c, PolyCollection)]
        assert len(fills) >= 2  # outer + inner band
    finally:
        # Close the figure even when an assertion fails, so it does not leak
        # into later tests.
        plt.close(fig)


def test_plot_bootstrap_bands_creates_axes_when_none() -> None:
    """Without an ``ax`` argument the function supplies its own labelled Axes.

    The returned Axes must carry a non-empty x-axis label.
    """
    sweep = _tiny_sweep()
    bands = compute_bands(sweep, percentiles=(5, 25, 50, 75, 95))
    try:
        ax = plot_bootstrap_bands(bands, metric="peak_enroute_share")
        assert ax.get_xlabel() != ""
    finally:
        # The test does not own a figure handle here, so close every pyplot
        # figure, including on assertion failure.
        plt.close("all")


def test_plot_bootstrap_bands_unknown_metric() -> None:
    """Requesting a metric absent from the bands raises ``KeyError``.

    The error message must contain ``"not in BandResult"``.
    """
    full_bands = compute_bands(_tiny_sweep(), percentiles=(5, 25, 50, 75, 95))
    bogus_metric = "not-a-metric"
    with pytest.raises(KeyError, match="not in BandResult"):
        plot_bootstrap_bands(full_bands, metric=bogus_metric)


def test_plot_bootstrap_bands_missing_percentiles() -> None:
    """Bands computed with only the median percentile cannot be plotted.

    A ``BandResult`` built from ``percentiles=(50,)`` must make the plotter
    raise ``ValueError`` with a message mentioning "missing percentiles".
    """
    median_only = compute_bands(_tiny_sweep(), percentiles=(50,))
    with pytest.raises(ValueError, match="missing percentiles"):
        plot_bootstrap_bands(median_only, metric="failed_evacuation_count")