tests/test_report_plots.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.axes import Axes

from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.constants import STATE_ORDER
from iohmm_evac.report.loader import SimulationBundle, load_bundle
from iohmm_evac.report.plots import (
    plot_cumulative_departures,
    plot_emission_summary,
    plot_household_trajectories,
    plot_state_occupancy,
)


@pytest.fixture(autouse=True)
def _close_figures() -> Iterator[None]:
    """Release every matplotlib figure once the current test has finished.

    Declared ``autouse`` so every test in this module gets the cleanup without
    requesting it, keeping figures from piling up across tests.
    """
    yield None
    plt.close(fig="all")


@pytest.fixture
def bundle(tmp_path: Path) -> SimulationBundle:
    """Build a small simulated dataset and load it back as a ``SimulationBundle``.

    The simulation (50 households over 24 hours, seed 0) is written to a parquet
    file under pytest's per-test ``tmp_path`` and then read back with
    ``load_bundle``, so the plots are exercised on data that went through the
    real write/load path.
    """
    cfg = SimulationConfig(n_households=50, n_hours=24, seed=0)
    parquet_path = tmp_path / "fixture.parquet"
    simulated = simulate(cfg, np.random.default_rng(cfg.seed))
    write_results(simulated, parquet_path)
    return load_bundle(parquet_path)


def test_plot_state_occupancy_returns_axes(bundle: SimulationBundle) -> None:
    """Drawing onto a caller-supplied Axes returns that Axes, one band per state."""
    _fig, target = plt.subplots()
    result = plot_state_occupancy(bundle, ax=target)
    assert isinstance(result, Axes)
    assert result is target
    # stackplot contributes exactly one PolyCollection for each state
    expected_bands = len(STATE_ORDER)
    assert len(target.collections) == expected_bands


def test_plot_state_occupancy_creates_axes_when_none(bundle: SimulationBundle) -> None:
    """With ``ax=None`` the function must create and return its own Axes."""
    created = plot_state_occupancy(bundle, ax=None)
    assert isinstance(created, Axes)


def test_plot_cumulative_departures_returns_axes(bundle: SimulationBundle) -> None:
    """Drawing onto a supplied Axes returns that Axes with the departure curve on it.

    Only a lower bound on ``ax.lines`` is asserted: the cumulative-departure
    curve must be present. Any overlay lines the implementation adds (e.g.
    ``axvline`` markers, which also land in ``ax.lines``) are deliberately not
    counted, so the test does not pin their number.
    """
    _, ax = plt.subplots()
    returned = plot_cumulative_departures(bundle, ax=ax)
    assert isinstance(returned, Axes)
    # Same contract as plot_state_occupancy: the supplied Axes is drawn on and returned.
    assert returned is ax
    assert len(ax.lines) >= 1


def test_plot_cumulative_departures_default_axes(bundle: SimulationBundle) -> None:
    """Leaving out ``ax`` still produces and returns an Axes."""
    created = plot_cumulative_departures(bundle)
    assert isinstance(created, Axes)


def test_plot_household_trajectories(bundle: SimulationBundle) -> None:
    """Each requested household gets its own Axes in the returned sequence."""
    household_ids = [0, 5, 12]
    panels = plot_household_trajectories(bundle, household_ids=household_ids)
    assert len(panels) == len(household_ids)
    assert all(isinstance(panel, Axes) for panel in panels)


def test_plot_household_trajectories_with_supplied_axes(bundle: SimulationBundle) -> None:
    """Caller-supplied Axes are drawn on and returned in order, one per household."""
    ids = [0, 1]
    _, axes = plt.subplots(len(ids), 1)
    supplied = list(axes)
    returned = plot_household_trajectories(bundle, household_ids=ids, ax=supplied)
    assert len(returned) == len(ids)
    # A length check alone would also pass if the function ignored ``ax`` and
    # created fresh Axes; identity proves the supplied ones were used.
    assert all(got is want for got, want in zip(returned, supplied))


def test_plot_household_trajectories_single_household(bundle: SimulationBundle) -> None:
    """A single household still comes back as a one-element sequence of Axes."""
    result = plot_household_trajectories(bundle, household_ids=[0])
    assert len(result) == 1


def test_plot_household_trajectories_too_many(bundle: SimulationBundle) -> None:
    """Asking for ten households is rejected with an "at most" ValueError."""
    too_many = list(range(10))
    with pytest.raises(ValueError, match="at most"):
        plot_household_trajectories(bundle, household_ids=too_many)


def test_plot_household_trajectories_empty(bundle: SimulationBundle) -> None:
    """An empty id list is rejected: at least one household is required."""
    no_ids: list[int] = []
    with pytest.raises(ValueError, match="at least one"):
        plot_household_trajectories(bundle, household_ids=no_ids)


def test_plot_household_trajectories_unknown_id(bundle: SimulationBundle) -> None:
    """An id absent from the 50-household fixture raises a "not found" ValueError."""
    missing_id = 99999
    with pytest.raises(ValueError, match="not found"):
        plot_household_trajectories(bundle, household_ids=[missing_id])


def test_plot_household_trajectories_axis_count_mismatch(bundle: SimulationBundle) -> None:
    """Supplying fewer Axes than households is rejected."""
    _, lone_ax = plt.subplots()
    two_households = [0, 1]
    with pytest.raises(ValueError, match="one entry per"):
        plot_household_trajectories(bundle, household_ids=two_households, ax=[lone_ax])


def test_plot_emission_summary_returns_axes(bundle: SimulationBundle) -> None:
    """Drawing onto a supplied Axes returns it and draws a bounded number of bars.

    The summary plots 3 metrics with at most one bar per state. That gives at
    least one bar per metric and never more than ``3 * len(STATE_ORDER)``
    bars in total. Deriving the bound from ``STATE_ORDER`` avoids hard-coding
    the state count.
    """
    _, ax = plt.subplots()
    returned = plot_emission_summary(bundle, ax=ax)
    assert isinstance(returned, Axes)
    # Same contract as plot_state_occupancy: the supplied Axes is drawn on and returned.
    assert returned is ax
    n_bars = len(ax.patches)
    assert 3 <= n_bars <= 3 * len(STATE_ORDER)


def test_plot_emission_summary_default_axes(bundle: SimulationBundle) -> None:
    """Leaving out ``ax`` still produces and returns an Axes."""
    created = plot_emission_summary(bundle)
    assert isinstance(created, Axes)