# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Synthesize a population of households with their static covariates."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from numpy.random import Generator
from iohmm_evac.params import PopulationParams
from iohmm_evac.types import FloatArray, IntArray, Zone
__all__ = ["Population", "synthesize_population", "zone_codes"]
@dataclass(frozen=True, slots=True)
class Population:
"""Static household covariates for the entire cohort."""
distance: FloatArray
"""Distance to coast (km), shape (N,)."""
vehicle: FloatArray
"""Vehicle access indicator in {0.0, 1.0}, shape (N,)."""
risk: FloatArray
"""Standard-normal risk sensitivity, shape (N,)."""
zone: IntArray
"""Encoded zone (0=A, 1=B, 2=C), shape (N,)."""
destination: FloatArray
"""Destination distance if evacuating (km), shape (N,)."""
@property
def n(self) -> int:
"""Number of households."""
return int(self.distance.shape[0])
def _truncated_normal(
    rng: Generator, mu: float, sigma: float, lo: float, hi: float, size: int
) -> FloatArray:
    """Sample ``size`` values from a normal distribution truncated to ``[lo, hi]``.

    Vectorized rejection sampling: each pass draws a 25%-oversampled batch
    from ``N(mu, sigma)``, keeps the in-bounds values in draw order, and
    repeats until ``size`` values have been collected. The sequence of RNG
    calls depends only on the arguments, so results are reproducible for a
    given generator state.

    Args:
        rng: NumPy random generator to draw from (advanced in place).
        mu: Mean of the untruncated normal.
        sigma: Standard deviation of the untruncated normal (``>= 0``).
        lo: Inclusive lower bound.
        hi: Inclusive upper bound.
        size: Number of samples to return.

    Returns:
        A ``float64`` array of shape ``(size,)`` with every entry in ``[lo, hi]``.

    Raises:
        ValueError: If ``size`` is negative, or if ``size > 0`` and the
            parameters can never produce an accepted draw: non-finite ``mu``
            or ``sigma``, negative ``sigma``, ``lo >= hi`` with a positive
            ``sigma``, NaN bounds, or ``sigma == 0`` with ``mu`` outside
            ``[lo, hi]``.

    Note:
        If ``[lo, hi]`` lies far out in a tail of ``N(mu, sigma)`` the
        acceptance rate is tiny and the loop may take very long; the checks
        only reject inputs that can *never* succeed.
    """
    out = np.empty(size, dtype=np.float64)
    if size == 0:
        # Nothing to draw, so the distribution parameters are never consulted.
        return out

    # Reject parameter combinations for which the loop below would spin
    # forever. Comparisons are written so that NaN bounds also fail.
    if not (np.isfinite(mu) and np.isfinite(sigma)):
        raise ValueError(f"mu and sigma must be finite, got mu={mu!r}, sigma={sigma!r}")
    if sigma < 0:
        raise ValueError(f"sigma must be non-negative, got {sigma!r}")
    if sigma == 0:
        # Every draw equals mu exactly, so mu itself has to be admissible.
        if not (lo <= mu <= hi):
            raise ValueError(
                f"sigma is 0 but mu={mu!r} lies outside [{lo!r}, {hi!r}]"
            )
    elif not (lo < hi):
        # A continuous draw hits a single point (or an empty range) with
        # probability zero.
        raise ValueError(f"bounds must satisfy lo < hi, got lo={lo!r}, hi={hi!r}")

    filled = 0
    while filled < size:
        need = size - filled
        # Oversample by 25% (plus a small constant) to keep the number of
        # passes low when most draws are accepted.
        draw = rng.normal(mu, sigma, size=int(need * 1.25) + 8)
        ok = draw[(draw >= lo) & (draw <= hi)]
        # Slicing clamps automatically when fewer than `need` were accepted.
        take = ok[:need]
        out[filled : filled + take.shape[0]] = take
        filled += take.shape[0]
    return out
def zone_codes(distance: FloatArray, params: PopulationParams) -> IntArray:
    """Map coastal distances to zone codes (0=A, 1=B, 2=C).

    Boundaries are half-open on the coastal side: a distance ``d`` is zone A
    when ``d < params.zone_a_threshold``, zone B when it is not in A and
    ``d < params.zone_b_threshold``, and zone C otherwise. NaN distances
    compare false against both thresholds and therefore end up in zone C.

    Args:
        distance: Array of distances to the coast; any shape is accepted.
        params: Object providing ``zone_a_threshold`` and ``zone_b_threshold``.

    Returns:
        An ``int64`` array of zone codes with the same shape as ``distance``.
    """
    # Start with everyone in C, then overwrite with the nearer zones. Applying
    # the A mask last lets it win wherever both comparisons hold.
    z = np.full(distance.shape, 2, dtype=np.int64)
    z[distance < params.zone_b_threshold] = 1
    z[distance < params.zone_a_threshold] = 0
    return z
def synthesize_population(n: int, rng: Generator, params: PopulationParams) -> Population:
    """Draw the static covariates of a cohort of ``n`` households.

    Distance, vehicle access, risk sensitivity and destination distance are
    each drawn from ``rng`` in that order; the zone is not random but is
    computed from the sampled distance.

    Args:
        n: Number of households to generate.
        rng: NumPy random generator, advanced in place.
        params: Distribution parameters for each covariate and the zone
            thresholds.

    Returns:
        A ``Population`` whose arrays all have length ``n``.
    """
    # Coastal distance: normal restricted to [distance_lo, distance_hi].
    coast_km = _truncated_normal(
        rng,
        params.distance_mu,
        params.distance_sigma,
        params.distance_lo,
        params.distance_hi,
        n,
    )

    # Bernoulli(vehicle_p) indicator stored as 0.0 / 1.0 floats.
    owns_vehicle = rng.random(size=n) < params.vehicle_p
    vehicle_flag = owns_vehicle.astype(np.float64)

    risk_sensitivity = rng.normal(params.risk_mu, params.risk_sigma, size=n)

    # Deterministic given the distances; consumes no random numbers.
    zone_code = zone_codes(coast_km, params)

    dest_km = rng.uniform(params.dest_lo, params.dest_hi, size=n)

    return Population(
        distance=coast_km,
        vehicle=vehicle_flag,
        risk=risk_sensitivity,
        zone=zone_code,
        destination=dest_km,
    )