# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Predefined scenarios that produce :class:`SimulationConfig` instances."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import replace
from iohmm_evac.params import (
FeedbackParams,
PopulationParams,
SimulationConfig,
TimelineParams,
)
__all__ = ["SCENARIO_BUILDERS", "build_scenario", "list_scenarios"]
def _baseline() -> SimulationConfig:
    """Build the reference scenario: a :class:`SimulationConfig` with all defaults.

    Every other scenario in this module is expressed as a delta on this config.
    """
    config = SimulationConfig()
    return config
def _early_warning() -> SimulationConfig:
    """Return the baseline config with an earlier order timeline.

    Sets ``timeline.voluntary_hour`` to 48 and ``timeline.mandatory_hour`` to 72.
    Every other timeline field is carried over from the baseline config's own
    ``timeline`` instance rather than being reset to the ``TimelineParams``
    class defaults.
    """
    base = SimulationConfig()
    # Derive from the baseline's timeline so only the two scenario-specific
    # fields change; constructing ``TimelineParams(...)`` from scratch would
    # silently drop any non-default values the baseline config carries.
    timeline = replace(base.timeline, voluntary_hour=48, mandatory_hour=72)
    return replace(base, timeline=timeline)
def _targeted_messaging() -> SimulationConfig:
    """Return the baseline config with ``population.targeted_zone_multiplier`` = 1.5.

    Only that population field is overridden. All other population fields are
    carried over from the baseline config's own ``population`` instance.
    NOTE(review): the multiplier's exact effect is defined in ``PopulationParams``
    and its consumers, not here; confirm its semantics there.
    """
    base = SimulationConfig()
    # Derive from the baseline's population params so the scenario is a true
    # single-field delta on the baseline instead of a reset to class defaults.
    population = replace(base.population, targeted_zone_multiplier=1.5)
    return replace(base, population=population)
def _contraflow() -> SimulationConfig:
    """Return the baseline config with ``feedback.n_cap`` set to 2500.

    Only ``n_cap`` is overridden. All other feedback fields are carried over
    from the baseline config's own ``feedback`` instance.
    NOTE(review): ``n_cap`` presumably models the capacity raised by contraflow
    lanes; confirm against ``FeedbackParams``.
    """
    base = SimulationConfig()
    # Derive from the baseline's feedback params rather than FeedbackParams()
    # defaults so no other baseline feedback setting is silently reverted.
    feedback = replace(base.feedback, n_cap=2500)
    return replace(base, feedback=feedback)
# Registry of scenario name -> zero-argument factory returning a newly built
# SimulationConfig. ``build_scenario`` looks names up here and
# ``list_scenarios`` reports its keys, so registering a new scenario only
# requires adding an entry below.
SCENARIO_BUILDERS: dict[str, Callable[[], SimulationConfig]] = {
    "baseline": _baseline,
    "early-warning": _early_warning,
    "targeted-messaging": _targeted_messaging,
    "contraflow": _contraflow,
}
def list_scenarios() -> list[str]:
    """List every registered scenario name in alphabetical order.

    Sorting makes the output independent of registration order, so it is
    stable wherever it is shown (e.g. in ``build_scenario``'s error message).
    """
    names = list(SCENARIO_BUILDERS)
    names.sort()
    return names
def build_scenario(name: str) -> SimulationConfig:
    """Construct the :class:`SimulationConfig` registered under *name*.

    Args:
        name: A key of ``SCENARIO_BUILDERS`` (see ``list_scenarios``).

    Returns:
        The config produced by calling the registered factory.

    Raises:
        ValueError: If *name* is not a registered scenario; the message
            includes the sorted list of known names.
    """
    builder = SCENARIO_BUILDERS.get(name)
    if builder is None:
        msg = f"Unknown scenario: {name!r}. Known: {list_scenarios()}"
        raise ValueError(msg)
    return builder()