tests/test_report_loader.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
import pytest

from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.loader import load_bundle


def _write_tiny(tmp_path: Path) -> Path:
    """Simulate a tiny scenario (20 households, 8 hours, seed 0) into *tmp_path*.

    Returns the path of the main ``tiny.parquet`` file. The tests in this
    module expect ``write_results`` to place the population, timeline and
    config sidecars alongside it in the same directory.
    """
    target = tmp_path / "tiny.parquet"
    cfg = SimulationConfig(n_households=20, n_hours=8, seed=0)
    simulation = simulate(cfg, np.random.default_rng(cfg.seed))
    write_results(simulation, target)
    return target


def test_load_bundle_round_trip(tmp_path: Path) -> None:
    """A bundle written by ``_write_tiny`` loads back with the expected contents."""
    bundle = load_bundle(_write_tiny(tmp_path))

    # Each tabular artifact comes back as a DataFrame; the config as a dict.
    for frame in (bundle.observations, bundle.population, bundle.timeline):
        assert isinstance(frame, pd.DataFrame)
    assert isinstance(bundle.config, dict)

    # Values reflect the n_households=20 / n_hours=8 config used by _write_tiny.
    assert bundle.n_households == 20
    assert bundle.t_landfall == 8

    required_columns = {"household_id", "t", "state", "departure", "displacement", "comm_count"}
    assert required_columns <= set(bundle.observations.columns)
    assert bundle.config["n_hours"] == 8


def test_load_bundle_missing_population(tmp_path: Path) -> None:
    """Deleting the population sidecar makes ``load_bundle`` raise and name it.

    The keyword is checked against the error message with the tmp directory's
    name removed. pytest derives that directory name from the (truncated) test
    name, so a plain ``match="population"`` could be satisfied by the path
    alone rather than by the loader identifying which artifact is missing.
    """
    out = _write_tiny(tmp_path)
    (tmp_path / "tiny.population.parquet").unlink()
    with pytest.raises(FileNotFoundError) as excinfo:
        load_bundle(out)
    message = str(excinfo.value).replace(tmp_path.name, "")
    assert "population" in message


def test_load_bundle_missing_timeline(tmp_path: Path) -> None:
    """Deleting the timeline sidecar makes ``load_bundle`` raise and name it.

    The keyword is checked against the error message with the tmp directory's
    name removed. pytest derives that directory name from the (truncated) test
    name, so a plain ``match="timeline"`` could be satisfied by the path alone
    rather than by the loader identifying which artifact is missing.
    """
    out = _write_tiny(tmp_path)
    (tmp_path / "tiny.timeline.parquet").unlink()
    with pytest.raises(FileNotFoundError) as excinfo:
        load_bundle(out)
    message = str(excinfo.value).replace(tmp_path.name, "")
    assert "timeline" in message


def test_load_bundle_missing_config(tmp_path: Path) -> None:
    """Deleting the config sidecar makes ``load_bundle`` raise and name it.

    The keyword is checked against the error message with the tmp directory's
    name removed. pytest derives that directory name from the (truncated) test
    name, so a plain ``match="config"`` could be satisfied by the path alone
    rather than by the loader identifying which artifact is missing.
    """
    out = _write_tiny(tmp_path)
    (tmp_path / "tiny.config.toml").unlink()
    with pytest.raises(FileNotFoundError) as excinfo:
        load_bundle(out)
    message = str(excinfo.value).replace(tmp_path.name, "")
    assert "config" in message


def test_load_bundle_missing_observations(tmp_path: Path) -> None:
    """Deleting the main observations file makes ``load_bundle`` raise and name it.

    The keyword is checked against the error message with the tmp directory's
    name removed. pytest derives that directory name from the (truncated) test
    name, so a plain ``match="observations"`` could be satisfied by the path
    alone rather than by the loader identifying which artifact is missing.
    """
    out = _write_tiny(tmp_path)
    out.unlink()
    with pytest.raises(FileNotFoundError) as excinfo:
        load_bundle(out)
    message = str(excinfo.value).replace(tmp_path.name, "")
    assert "observations" in message