# src/mine_sim/rng.py

# ← Back to submission · View raw on GitHub

"""Reproducible random-number generation for replications.

The reproducibility contract is:

    All randomness derives from ``base_random_seed`` in the scenario YAML;
    the per-replication seed is ``base_seed + replication_index``.

This module is the single source of truth for that rule. Wherever the
simulation needs randomness, it asks this module for a generator, so two
runs of the same ``(scenario, replication_index)`` produce bit-identical
metric outputs regardless of run order.

Design notes
------------
* ``numpy.random.SeedSequence(seed).spawn(n)`` derives ``n`` *independent*
  child generators from the per-replication seed — the recommended numpy
  pattern for splitting one entropy source into uncorrelated streams.
* Each stochastic source has its own named stream (``loading``,
  ``dumping``, ``edge_noise``, ``dispatch``, ``misc``). Independent streams
  mean a refactor that changes call ordering inside one stream does not
  shift the others. Appending a new stream is safe; reordering/renaming is
  a breaking change that must bump scenario seeds intentionally.

The module is deliberately small, depends only on NumPy and the standard
library, and is free of SimPy imports so it can be unit-tested in isolation.
"""

from __future__ import annotations

from dataclasses import dataclass
from types import MappingProxyType
from typing import Final, Iterable, Mapping

import numpy as np

# ---------------------------------------------------------------------------
# Public constants
# ---------------------------------------------------------------------------
#: Ordered stream names. Order is part of the reproducibility contract:
#: ``make_replication_rng`` pairs these names positionally with the children
#: spawned from the per-replication ``SeedSequence``, so reordering or
#: renaming an entry changes which child seeds which stream.
STREAM_NAMES: Final[tuple[str, ...]] = (
    "loading",     # truncated-normal load times at LOAD_N / LOAD_S
    "dumping",     # truncated-normal dump times at CRUSH
    "edge_noise",  # lognormal travel-time multiplier per edge traversal
    "dispatch",    # reserved for randomized dispatching / tie-breaks
    "misc",        # reserved for ad-hoc draws
)

#: Maximum rejection-sampling attempts in ``truncated_normal`` before it
#: gives up and returns one extra draw clamped at the floor. High enough that
#: realistic parameters terminate in one or two draws.
_TRUNCATION_MAX_ATTEMPTS: Final[int] = 64

#: Default lower bound (the ``minimum`` argument) for ``truncated_normal``,
#: intended for load/dump samples. Draws below it are rejected and redrawn;
#: an actual ``max(floor, sample)`` clamp is applied only in the fallback
#: after the attempt budget is exhausted, or to the mean when ``sd <= 0``.
DEFAULT_TRUNCATION_FLOOR: Final[float] = 0.1


def replication_seed(base_seed: int, replication_index: int) -> int:
    """Derive the seed for one replication as ``base_seed + replication_index``.

    Neither argument may be negative, which guarantees the derived seed is
    non-negative as well. ``base_seed`` is validated before
    ``replication_index``, so when both are invalid the error names
    ``base_seed``.

    Returns the sum of the two arguments after coercing each with ``int``.

    Raises ``ValueError`` if either argument is below zero.
    """
    # Table-driven validation: each entry is (name used in the message, value).
    checks = (
        ("base_seed", base_seed),
        ("replication_index", replication_index),
    )
    for label, value in checks:
        if value < 0:
            raise ValueError(f"{label} must be >= 0, got {value}")
    return int(base_seed) + int(replication_index)


@dataclass(frozen=True)
class ReplicationRNG:
    """Immutable bundle of independent generators for one replication.

    Construct via :func:`make_replication_rng`. Whatever mapping is supplied
    as ``streams`` is copied and wrapped in ``MappingProxyType`` when the
    instance is created, so the set of streams cannot be changed afterwards:
    neither through ``instance.streams`` nor through the mapping object the
    caller originally passed in.

    Only the *mapping* is read-only. The generators it holds are ordinary
    stateful ``numpy.random.Generator`` objects; drawing from one advances it.
    """

    # Seed the streams were derived from (``base_seed + replication_index``
    # when built by ``make_replication_rng``).
    seed: int
    # Scenario-level seed this replication was derived from.
    base_seed: int
    # Index of the replication within the scenario.
    replication_index: int
    # Read-only view of stream name -> generator, in the order supplied.
    streams: Mapping[str, np.random.Generator]

    def __post_init__(self) -> None:
        """Detach ``streams`` from the caller's mapping and make it read-only."""
        # ``frozen=True`` blocks normal attribute assignment, so the
        # replacement has to go through ``object.__setattr__``. Copying into
        # a fresh dict first means later edits to the caller's mapping do not
        # leak into this bundle.
        object.__setattr__(
            self, "streams", MappingProxyType(dict(self.streams))
        )

    def __getitem__(self, name: str) -> np.random.Generator:
        """Return the generator registered under ``name``.

        Raises ``KeyError`` listing the known stream names when ``name`` is
        not one of them.
        """
        try:
            return self.streams[name]
        except KeyError as exc:
            raise KeyError(
                f"Unknown RNG stream {name!r}. Known: {tuple(self.streams)}"
            ) from exc

    def __contains__(self, name: object) -> bool:
        """Return whether a stream called ``name`` exists in this bundle."""
        return name in self.streams

    @property
    def stream_names(self) -> tuple[str, ...]:
        """Stream names in the order they were supplied at construction."""
        return tuple(self.streams)


def make_replication_rng(
    base_seed: int,
    replication_index: int,
    stream_names: Iterable[str] = STREAM_NAMES,
) -> ReplicationRNG:
    """Deterministically assemble the :class:`ReplicationRNG` for one replication.

    The per-replication seed comes from :func:`replication_seed`. A
    ``SeedSequence`` built from that seed spawns one child per stream name,
    and the i-th child seeds the generator for the i-th name, so the order
    of ``stream_names`` determines which child each stream receives.

    Raises ``ValueError`` when ``stream_names`` is empty or contains
    duplicates, and propagates the ``ValueError`` from
    :func:`replication_seed` for negative seed inputs. The name checks run
    before the seed is derived.
    """
    # Materialise once: the iterable may be single-pass, and it is needed
    # for validation, for the child count, and for the pairing below.
    ordered_names = tuple(stream_names)
    if len(ordered_names) == 0:
        raise ValueError("stream_names must be a non-empty iterable")
    if len(ordered_names) != len(set(ordered_names)):
        raise ValueError(f"stream_names must be unique, got {ordered_names}")

    derived_seed = replication_seed(base_seed, replication_index)
    child_sequences = np.random.SeedSequence(derived_seed).spawn(
        len(ordered_names)
    )

    generators: dict[str, np.random.Generator] = {}
    for stream_name, child in zip(ordered_names, child_sequences, strict=True):
        generators[stream_name] = np.random.default_rng(child)

    return ReplicationRNG(
        seed=derived_seed,
        base_seed=int(base_seed),
        replication_index=int(replication_index),
        # Hand out a read-only view so callers cannot add or swap streams.
        streams=MappingProxyType(generators),
    )


def truncated_normal(
    rng: np.random.Generator,
    mean: float,
    sd: float,
    minimum: float = DEFAULT_TRUNCATION_FLOOR,
) -> float:
    """Sample ``N(mean, sd)`` conditioned on the result being ``>= minimum``.

    Draws are taken from ``rng`` one at a time and the first that is at
    least ``minimum`` is returned, which reproduces the truncated normal
    density on ``[minimum, +inf)``. If :data:`_TRUNCATION_MAX_ATTEMPTS`
    consecutive draws are all rejected, one further draw is taken and
    clamped up to ``minimum``; this is a safety net for pathological
    parameters rather than part of the intended distribution.

    When ``sd <= 0`` nothing is drawn and ``max(minimum, mean)`` is returned.
    """
    if sd <= 0:
        # Degenerate spread: deterministic result, ``rng`` is left untouched.
        return max(float(minimum), float(mean))

    attempts_left = _TRUNCATION_MAX_ATTEMPTS
    while attempts_left > 0:
        draw = float(rng.normal(loc=mean, scale=sd))
        if draw >= minimum:
            return draw
        attempts_left -= 1

    # Attempt budget exhausted: take one more draw and clamp it at the floor.
    floor = float(minimum)
    last_draw = float(rng.normal(loc=mean, scale=sd))
    return max(floor, last_draw)


def lognormal_noise_multiplier(rng: np.random.Generator, cv: float) -> float:
    """Lognormal multiplier with mean 1 and coefficient of variation ``cv``.

    For ``X = exp(N(mu, sigma))`` with ``E[X] = 1`` and ``CV[X] = cv``::

        sigma^2 = ln(1 + cv^2)
        mu      = -sigma^2 / 2

    ``cv == 0`` short-circuits to 1.0 (no draw consumed) so seed alignment
    stays intuitive when noise is disabled. Any other valid ``cv`` consumes
    exactly one normal draw from ``rng``.

    Raises ``ValueError`` if ``cv`` is negative, NaN, or infinite. NaN and
    ``+inf`` are rejected explicitly because they pass a plain ``cv < 0``
    test and would otherwise propagate into a NaN or degenerate multiplier.
    """
    if cv < 0:
        raise ValueError(f"cv must be non-negative, got {cv}")
    if not np.isfinite(cv):
        # Reached only for NaN and +inf; -inf is caught by the check above.
        raise ValueError(f"cv must be finite, got {cv}")
    if cv == 0:
        return 1.0
    # log1p keeps precision for the small cv values typical of noise terms.
    sigma_sq = float(np.log1p(cv * cv))
    sigma = float(np.sqrt(sigma_sq))
    # Shift the underlying normal so that exp() of it has mean exactly 1.
    mu = -0.5 * sigma_sq
    return float(np.exp(rng.normal(loc=mu, scale=sigma)))


# Names exported by a star-import of this module, kept in sorted order.
# The private ``_TRUNCATION_MAX_ATTEMPTS`` constant is not listed.
__all__ = [
    "DEFAULT_TRUNCATION_FLOOR",
    "ReplicationRNG",
    "STREAM_NAMES",
    "lognormal_noise_multiplier",
    "make_replication_rng",
    "replication_seed",
    "truncated_normal",
]