# src/iohmm_evac/dgp/population.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Synthesize a population of households with their static covariates."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.random import Generator

from iohmm_evac.params import PopulationParams
from iohmm_evac.types import FloatArray, IntArray, Zone

__all__ = ["Population", "synthesize_population", "zone_codes"]


@dataclass(frozen=True, slots=True)
class Population:
    """Static household covariates for the entire cohort."""

    distance: FloatArray
    """Distance to coast (km), shape (N,)."""

    vehicle: FloatArray
    """Vehicle access indicator in {0.0, 1.0}, shape (N,)."""

    risk: FloatArray
    """Standard-normal risk sensitivity, shape (N,)."""

    zone: IntArray
    """Encoded zone (0=A, 1=B, 2=C), shape (N,)."""

    destination: FloatArray
    """Destination distance if evacuating (km), shape (N,)."""

    @property
    def n(self) -> int:
        """Number of households."""
        return int(self.distance.shape[0])


def _truncated_normal(
    rng: Generator, mu: float, sigma: float, lo: float, hi: float, size: int
) -> FloatArray:
    """Sample a truncated normal by rejection sampling.

    Vectorized: oversamples in batches and refills any out-of-bounds entries
    until all are within ``[lo, hi]``. For the spec's ranges the acceptance
    probability is high, so a single oversample is usually enough.
    """
    out = np.empty(size, dtype=np.float64)
    filled = 0
    while filled < size:
        need = size - filled
        # Oversample by 25% to keep iterations low; this loop terminates
        # almost surely for any non-degenerate (lo < hi).
        draw = rng.normal(mu, sigma, size=int(need * 1.25) + 8)
        ok = draw[(draw >= lo) & (draw <= hi)]
        take = ok[: min(need, ok.shape[0])]
        out[filled : filled + take.shape[0]] = take
        filled += take.shape[0]
    return out


def zone_codes(distance: FloatArray, params: PopulationParams) -> IntArray:
    """Map coastal distances to zone codes (0=A, 1=B, 2=C).

    A household is in zone A if its distance is below
    ``params.zone_a_threshold``, in zone B if it is below
    ``params.zone_b_threshold`` (and not in A), and in zone C otherwise.
    A NaN distance compares false against both thresholds and therefore
    keeps the default code 2.

    Args:
        distance: Coastal distances, indexed along the first axis.
        params: Provides ``zone_a_threshold`` and ``zone_b_threshold``.

    Returns:
        An int64 array of length ``distance.shape[0]`` holding the codes.
    """
    zone_a, zone_b, zone_c = 0, 1, 2
    # Start everyone in the outermost zone, then overwrite with closer zones.
    codes = np.full(distance.shape[0], zone_c, dtype=np.int64)
    # Order matters: the A mask is applied last so it wins over the B mask.
    codes[distance < params.zone_b_threshold] = zone_b
    codes[distance < params.zone_a_threshold] = zone_a
    return codes


def synthesize_population(n: int, rng: Generator, params: PopulationParams) -> Population:
    """Draw independent static covariates for a cohort of ``n`` households.

    The generator is consumed in a fixed order (distance, vehicle, risk,
    destination) so that a given seed always yields the same cohort.
    """
    # Coastal distance: normal restricted to the configured bounds.
    coast_km = _truncated_normal(
        rng,
        params.distance_mu,
        params.distance_sigma,
        params.distance_lo,
        params.distance_hi,
        n,
    )
    # Zones are a deterministic function of distance and use no randomness.
    zone_ids = zone_codes(coast_km, params)

    # Vehicle access: Bernoulli(vehicle_p) stored as 0.0 / 1.0.
    uniform_draws = rng.random(size=n)
    has_vehicle = (uniform_draws < params.vehicle_p).astype(np.float64)

    risk_scores = rng.normal(params.risk_mu, params.risk_sigma, size=n)
    dest_km = rng.uniform(params.dest_lo, params.dest_hi, size=n)

    return Population(
        distance=coast_km,
        vehicle=has_vehicle,
        risk=risk_scores,
        zone=zone_ids,
        destination=dest_km,
    )