"""Scenario configuration data structures and YAML loading.

Scenarios are defined as YAML files with optional inheritance from a base
scenario (typically `baseline`). Override blocks let a scenario change:

* `fleet.truck_count` - number of trucks released at t=0
* `edge_overrides.<edge_id>` - capacity, max_speed_kph, closed flag
  (used to model ramp upgrades and closures)
* `dump_point_overrides.<dump>` - crusher mean/sd dump time
* `node_overrides.<node_id>` - service-time and capacity overrides for a
  node (e.g. the dump node)
* `simulation.*` - shift length, replications, seed, warmup
* `routing.*`, `dispatching.*`,
  `stochasticity.*` - tunable policy / noise parameters

The resulting :class:`ScenarioConfig` is a frozen dataclass; reassigning its
fields is forbidden so simulation runs cannot accidentally drift from the
spec on disk.
"""
from __future__ import annotations
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any, Mapping
import yaml
# ---------------------------------------------------------------------------
# Defaults — used when a YAML file (or its ancestor chain) does not specify a
# value. Kept here, not in the YAML, so behaviour is explicit in code review.
# ---------------------------------------------------------------------------
# Length of one simulated shift, in hours.
DEFAULT_SHIFT_LENGTH_HOURS = 8
# Number of replications per scenario.
DEFAULT_REPLICATIONS = 30
# Base seed; ScenarioConfig.replication_seed adds the replication index to it.
DEFAULT_BASE_RANDOM_SEED = 12345
# Warm-up period in minutes (0 == no warm-up).
DEFAULT_WARMUP_MINUTES = 0
# Travel-time noise "cv" - presumably a coefficient of variation; confirm
# against the consumer of StochasticityParams.travel_time_noise_cv.
DEFAULT_TRAVEL_NOISE_CV = 0.10
# Fleet size used when a scenario does not set fleet.truck_count.
DEFAULT_TRUCK_COUNT = 8
# ---------------------------------------------------------------------------
# Override dataclasses — each is a sparse patch applied on top of the CSV
# topology / fleet / dump-point data. Missing fields == "do not override".
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class EdgeOverride:
    """Sparse patch for a single edge (one ``edge_id`` row of edges.csv).

    Each attribute defaults to ``None``, meaning "do not override" the
    corresponding value.
    """

    capacity: int | None = None
    max_speed_kph: float | None = None
    closed: bool | None = None

    @classmethod
    def from_mapping(cls, raw: Mapping[str, Any]) -> "EdgeOverride":
        """Create an override from a raw mapping such as a parsed YAML block.

        Only ``capacity``, ``max_speed_kph`` and ``closed`` are read. Keys
        that are absent become ``None``; any other keys are ignored. Values
        are stored exactly as given, without type conversion.
        """
        wanted = ("capacity", "max_speed_kph", "closed")
        return cls(**{name: raw.get(name) for name in wanted})
@dataclass(frozen=True)
class DumpPointOverride:
    """Sparse patch for a single dump point (one row of dump_points.csv).

    Each attribute defaults to ``None``, meaning "do not override" the
    corresponding value.
    """

    mean_dump_time_min: float | None = None
    sd_dump_time_min: float | None = None

    @classmethod
    def from_mapping(cls, raw: Mapping[str, Any]) -> "DumpPointOverride":
        """Create an override from a raw mapping such as a parsed YAML block.

        Only ``mean_dump_time_min`` and ``sd_dump_time_min`` are read. Keys
        that are absent become ``None``; any other keys are ignored. Values
        are stored exactly as given, without type conversion.
        """
        wanted = ("mean_dump_time_min", "sd_dump_time_min")
        return cls(**{name: raw.get(name) for name in wanted})
@dataclass(frozen=True)
class NodeOverride:
    """Sparse patch for a single node (one row of nodes.csv).

    Each attribute defaults to ``None``, meaning "do not override" the
    corresponding value.
    """

    service_time_mean_min: float | None = None
    service_time_sd_min: float | None = None
    capacity: int | None = None

    @classmethod
    def from_mapping(cls, raw: Mapping[str, Any]) -> "NodeOverride":
        """Create an override from a raw mapping such as a parsed YAML block.

        Only ``service_time_mean_min``, ``service_time_sd_min`` and
        ``capacity`` are read. Keys that are absent become ``None``; any
        other keys are ignored. Values are stored exactly as given, without
        type conversion.
        """
        wanted = ("service_time_mean_min", "service_time_sd_min", "capacity")
        return cls(**{name: raw.get(name) for name in wanted})
# ---------------------------------------------------------------------------
# Policy / stochasticity blocks
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class SimulationParams:
    """Run-control settings: shift length, replication count, seed, warm-up.

    Every field falls back to the matching module-level ``DEFAULT_*``
    constant.
    """

    shift_length_hours: int = DEFAULT_SHIFT_LENGTH_HOURS
    replications: int = DEFAULT_REPLICATIONS
    base_random_seed: int = DEFAULT_BASE_RANDOM_SEED
    warmup_minutes: float = DEFAULT_WARMUP_MINUTES

    @property
    def shift_length_minutes(self) -> float:
        """Shift length expressed in minutes (``shift_length_hours * 60.0``)."""
        minutes_per_hour = 60.0
        return self.shift_length_hours * minutes_per_hour
@dataclass(frozen=True)
class RoutingParams:
    """Routing policy settings carried by a scenario.

    The field meanings below are taken from the names only.
    NOTE(review): confirm the exact semantics against the routing code.
    """

    # Name of the routing objective; defaults to "shortest_time".
    objective: str = "shortest_time"
    # Flag controlling bypass usage; enabled by default.
    allow_bypass: bool = True
    # Flag controlling road-capacity handling; enabled by default.
    road_capacity_enabled: bool = True
@dataclass(frozen=True)
class DispatchingParams:
    """Dispatching policy settings carried by a scenario.

    NOTE(review): the accepted policy / tie-breaker names are defined by the
    dispatcher, which is not part of this module - confirm there.
    """

    # Name of the dispatching policy; defaults to "nearest_available_loader".
    policy: str = "nearest_available_loader"
    # Name of the tie-breaking rule; defaults to
    # "shortest_expected_cycle_time".
    tie_breaker: str = "shortest_expected_cycle_time"
@dataclass(frozen=True)
class StochasticityParams:
    """Noise-model settings carried by a scenario.

    NOTE(review): the accepted distribution names and the exact meaning of
    the noise value are defined by the simulation code - confirm there.
    """

    # Distribution name for loading times; defaults to "normal_truncated".
    loading_time_distribution: str = "normal_truncated"
    # Distribution name for dumping times; defaults to "normal_truncated".
    dumping_time_distribution: str = "normal_truncated"
    # Travel-time noise level; falls back to DEFAULT_TRAVEL_NOISE_CV.
    travel_time_noise_cv: float = DEFAULT_TRAVEL_NOISE_CV
@dataclass(frozen=True)
class FleetParams:
    """Fleet settings carried by a scenario."""

    # Number of trucks in the fleet; falls back to DEFAULT_TRUCK_COUNT.
    truck_count: int = DEFAULT_TRUCK_COUNT
# ---------------------------------------------------------------------------
# Top-level scenario record
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class ScenarioConfig:
    """Fully resolved scenario, after inheritance and override merging.

    The three ``*_overrides`` mappings are keyed by the ids used in the CSV
    files (e.g. ``E03_UP``, ``D_CRUSH``, ``CRUSH``). The dataclass is frozen,
    so fields cannot be reassigned; use :meth:`with_overrides` to derive a
    modified copy.
    """

    scenario_id: str
    description: str = ""
    simulation: SimulationParams = field(default_factory=SimulationParams)
    routing: RoutingParams = field(default_factory=RoutingParams)
    dispatching: DispatchingParams = field(default_factory=DispatchingParams)
    stochasticity: StochasticityParams = field(default_factory=StochasticityParams)
    fleet: FleetParams = field(default_factory=FleetParams)
    edge_overrides: Mapping[str, EdgeOverride] = field(default_factory=dict)
    dump_point_overrides: Mapping[str, DumpPointOverride] = field(default_factory=dict)
    node_overrides: Mapping[str, NodeOverride] = field(default_factory=dict)
    # Id of the parent scenario this one was derived from, if any.
    inherits: str | None = None

    # Helpers ----------------------------------------------------------------
    def replication_seed(self, replication_index: int) -> int:
        """Return the seed for one replication: base seed plus its index."""
        seed_base = self.simulation.base_random_seed
        return seed_base + replication_index

    def closed_edge_ids(self) -> tuple[str, ...]:
        """Return the ids of edges whose override sets ``closed`` to ``True``.

        Only the literal ``True`` counts; ``None`` and other truthy values
        are not treated as closed. Order follows ``edge_overrides``.
        """
        closed: list[str] = []
        for edge_id, patch in self.edge_overrides.items():
            if patch.closed is True:
                closed.append(edge_id)
        return tuple(closed)

    def with_overrides(self, **kwargs: Any) -> "ScenarioConfig":
        """Return a copy with the named top-level fields replaced.

        Delegates to :func:`dataclasses.replace`; this instance is left
        untouched.
        """
        updated = replace(self, **kwargs)
        return updated
# ---------------------------------------------------------------------------
# YAML loading with inheritance (the `inherits` chain is resolved recursively)
# ---------------------------------------------------------------------------
def _read_yaml(path: Path) -> dict[str, Any]:
    """Parse a scenario YAML file and return its top-level mapping.

    Args:
        path: Location of the YAML file; read as UTF-8.

    Returns:
        The parsed mapping, or an empty dict when the document is empty
        (``yaml.safe_load`` returns ``None`` for it).

    Raises:
        ValueError: If the document's top level is anything other than a
            mapping - including falsy values such as ``[]``, ``0``,
            ``false`` or ``""``.
    """
    with path.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    # Only a genuinely empty document counts as "no settings". Using
    # `loaded or {}` here would also swallow falsy non-mapping documents and
    # let them slip past the type check below.
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(f"Scenario YAML at {path} must be a mapping at the top level")
    return loaded
def _deep_merge(base: dict[str, Any], patch: Mapping[str, Any]) -> dict[str, Any]:
    """Return a new dict holding `base` with `patch` layered on top.

    Where both sides hold a mapping under the same key (a ``dict`` in `base`,
    any ``Mapping`` in `patch`) the two are merged recursively. In every
    other case - scalars, lists, mismatched types, keys new to `base` - the
    value from `patch` simply replaces whatever was there. Neither argument
    is modified.
    """
    merged = dict(base)
    for name, incoming in patch.items():
        existing = merged.get(name)
        both_nested = isinstance(existing, dict) and isinstance(incoming, Mapping)
        merged[name] = _deep_merge(existing, incoming) if both_nested else incoming
    return merged
def _resolve_inherited(
    raw: Mapping[str, Any],
    scenarios_dir: Path,
    seen: tuple[str, ...] = (),
) -> dict[str, Any]:
    """Flatten a scenario by merging it onto its ``inherits`` ancestor chain.

    Args:
        raw: Parsed YAML mapping of the scenario being resolved.
        scenarios_dir: Directory searched for ``<parent_id>.yaml`` files.
        seen: Parent ids already opened along the current chain. Used only
            for cycle detection; callers normally leave it empty.

    Returns:
        A new dict with the child's values deep-merged over the fully
        resolved parent. When the scenario has a parent, the result's
        ``inherits`` key names that immediate parent; otherwise a shallow
        copy of ``raw`` is returned.

    Raises:
        ValueError: If the ``inherits`` chain loops back on itself.
        FileNotFoundError: If a referenced parent file does not exist.
    """
    parent_id = raw.get("inherits")
    if parent_id is None:
        return dict(raw)

    # Cycles are tracked by parent id - the file stem about to be opened -
    # rather than by each file's self-declared `scenario_id`. A file whose
    # scenario_id is missing or differs from its filename would otherwise
    # never match, and a cyclic chain would recurse until RecursionError.
    if parent_id in seen:
        chain = " -> ".join(str(link) for link in (*seen, parent_id))
        raise ValueError(f"Cyclic scenario inheritance detected: {chain}")

    parent_path = scenarios_dir / f"{parent_id}.yaml"
    if not parent_path.exists():
        raise FileNotFoundError(
            f"Scenario '{raw.get('scenario_id', '<unknown>')}' inherits from "
            f"'{parent_id}', but {parent_path} does not exist."
        )

    parent_resolved = _resolve_inherited(
        _read_yaml(parent_path),
        scenarios_dir,
        seen=(*seen, parent_id),
    )

    # Merge everything except the child's own `inherits` marker, then record
    # the immediate parent explicitly so the flattened result does not carry
    # the grandparent id that the resolved parent may hold.
    child = {k: v for k, v in raw.items() if k != "inherits"}
    merged = _deep_merge(parent_resolved, child)
    merged["inherits"] = parent_id
    return merged
def _build_simulation(raw: Mapping[str, Any]) -> SimulationParams:
    """Build :class:`SimulationParams` from the ``simulation`` YAML block.

    Missing keys fall back to the module-level ``DEFAULT_*`` constants. The
    hour, replication and seed values are converted with ``int``; the warm-up
    value with ``float``.
    """
    hours = raw.get("shift_length_hours", DEFAULT_SHIFT_LENGTH_HOURS)
    reps = raw.get("replications", DEFAULT_REPLICATIONS)
    seed = raw.get("base_random_seed", DEFAULT_BASE_RANDOM_SEED)
    warmup = raw.get("warmup_minutes", DEFAULT_WARMUP_MINUTES)
    return SimulationParams(
        shift_length_hours=int(hours),
        replications=int(reps),
        base_random_seed=int(seed),
        warmup_minutes=float(warmup),
    )
def _routing_flag(value: Any, key: str) -> bool:
    """Interpret a routing flag, understanding textual booleans.

    ``bool("false")`` is ``True`` in Python, so a quoted YAML value such as
    ``allow_bypass: "false"`` would otherwise silently enable the flag.
    Non-string values keep plain ``bool()`` semantics.
    """
    if isinstance(value, str):
        token = value.strip().lower()
        if token in ("true", "yes", "on", "1"):
            return True
        if token in ("false", "no", "off", "0"):
            return False
        raise ValueError(f"routing.{key} must be a boolean, got {value!r}")
    return bool(value)


def _build_routing(raw: Mapping[str, Any]) -> RoutingParams:
    """Build :class:`RoutingParams` from the ``routing`` YAML block.

    Args:
        raw: Mapping for the ``routing`` section; missing keys fall back to
            ``"shortest_time"`` / ``True`` / ``True``.

    Returns:
        The populated :class:`RoutingParams`.

    Raises:
        ValueError: If a flag is given as a string that is not a recognised
            boolean word (true/false, yes/no, on/off, 1/0).
    """
    return RoutingParams(
        objective=str(raw.get("objective", "shortest_time")),
        allow_bypass=_routing_flag(raw.get("allow_bypass", True), "allow_bypass"),
        road_capacity_enabled=_routing_flag(
            raw.get("road_capacity_enabled", True), "road_capacity_enabled"
        ),
    )
def _build_dispatching(raw: Mapping[str, Any]) -> DispatchingParams:
    """Build :class:`DispatchingParams` from the ``dispatching`` YAML block.

    Both values are converted with ``str``; missing keys fall back to
    ``"nearest_available_loader"`` and ``"shortest_expected_cycle_time"``.
    """
    chosen_policy = raw.get("policy", "nearest_available_loader")
    chosen_tie_breaker = raw.get("tie_breaker", "shortest_expected_cycle_time")
    return DispatchingParams(
        policy=str(chosen_policy),
        tie_breaker=str(chosen_tie_breaker),
    )
def _build_stochasticity(raw: Mapping[str, Any]) -> StochasticityParams:
    """Build :class:`StochasticityParams` from the ``stochasticity`` block.

    Distribution names are converted with ``str`` and default to
    ``"normal_truncated"``; the travel-time noise value is converted with
    ``float`` and defaults to ``DEFAULT_TRAVEL_NOISE_CV``.
    """
    loading = raw.get("loading_time_distribution", "normal_truncated")
    dumping = raw.get("dumping_time_distribution", "normal_truncated")
    noise_cv = raw.get("travel_time_noise_cv", DEFAULT_TRAVEL_NOISE_CV)
    return StochasticityParams(
        loading_time_distribution=str(loading),
        dumping_time_distribution=str(dumping),
        travel_time_noise_cv=float(noise_cv),
    )
def _build_fleet(raw: Mapping[str, Any]) -> FleetParams:
    """Build :class:`FleetParams` from the ``fleet`` YAML block.

    ``truck_count`` is converted with ``int`` and defaults to
    ``DEFAULT_TRUCK_COUNT`` when absent.
    """
    trucks = raw.get("truck_count", DEFAULT_TRUCK_COUNT)
    return FleetParams(truck_count=int(trucks))
def _build_overrides(
    raw: Mapping[str, Any],
    factory,
) -> dict[str, Any]:
    """Convert a block of per-id override mappings into override objects.

    Args:
        raw: Mapping from id (edge, dump point or node) to that id's raw
            override mapping.
        factory: Callable that turns one raw override mapping into an
            override object (e.g. ``EdgeOverride.from_mapping``).

    Returns:
        A dict with the same keys as ``raw`` and the factory's results as
        values.
    """
    # An id written with a null body in YAML (`E03_UP:`) parses to None.
    # Treat it as an empty patch - "override nothing" - instead of letting
    # the factory fail on `None.get`.
    return {
        key: factory({} if value is None else value)
        for key, value in raw.items()
    }
def scenario_from_mapping(raw: Mapping[str, Any]) -> ScenarioConfig:
    """Convert a fully merged YAML mapping into a :class:`ScenarioConfig`.

    Args:
        raw: Flattened scenario mapping (inheritance already resolved).

    Returns:
        The populated, frozen :class:`ScenarioConfig`.

    Raises:
        ValueError: If ``scenario_id`` is missing or empty.
    """
    scenario_id = raw.get("scenario_id")
    if not scenario_id:
        raise ValueError("Scenario YAML missing required 'scenario_id'")

    def section(name: str) -> Mapping[str, Any]:
        # A section header with a null body (`simulation:`) parses to None.
        # Treat it like an absent section, for parameter blocks and override
        # blocks alike.
        return raw.get(name) or {}

    # A null description must not become the literal string "None".
    description = raw.get("description")
    return ScenarioConfig(
        scenario_id=str(scenario_id),
        description="" if description is None else str(description),
        simulation=_build_simulation(section("simulation")),
        routing=_build_routing(section("routing")),
        dispatching=_build_dispatching(section("dispatching")),
        stochasticity=_build_stochasticity(section("stochasticity")),
        fleet=_build_fleet(section("fleet")),
        edge_overrides=_build_overrides(
            section("edge_overrides"), EdgeOverride.from_mapping
        ),
        dump_point_overrides=_build_overrides(
            section("dump_point_overrides"), DumpPointOverride.from_mapping
        ),
        node_overrides=_build_overrides(
            section("node_overrides"), NodeOverride.from_mapping
        ),
        inherits=raw.get("inherits"),
    )
def load_scenario(path: str | Path) -> ScenarioConfig:
    """Load one scenario YAML file and resolve its inheritance chain.

    Parent scenarios are looked up in the same directory as ``path``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    yaml_path = Path(path)
    if yaml_path.exists():
        flattened = _resolve_inherited(
            _read_yaml(yaml_path),
            scenarios_dir=yaml_path.parent,
        )
        return scenario_from_mapping(flattened)
    raise FileNotFoundError(f"Scenario YAML not found: {yaml_path}")
# Scenario ids that load_all_scenarios() expects to find by default; it
# raises if any of them is absent from the scenarios directory.
REQUIRED_SCENARIO_IDS: tuple[str, ...] = (
"baseline",
"trucks_4",
"trucks_12",
"ramp_upgrade",
"crusher_slowdown",
"ramp_closed",
"trucks_12_ramp_upgrade",
)
def load_all_scenarios(
    scenarios_dir: str | Path,
    required: tuple[str, ...] = REQUIRED_SCENARIO_IDS,
) -> dict[str, ScenarioConfig]:
    """Load every ``*.yaml`` scenario in a directory, keyed by scenario_id.

    Files are processed in sorted path order. After loading, each id in
    ``required`` must be present, so a missing or misnamed scenario fails
    loudly instead of silently dropping a run.

    Raises:
        NotADirectoryError: If ``scenarios_dir`` is not a directory.
        ValueError: If two files declare the same scenario_id, or if any
            required scenario_id is missing.
    """
    root = Path(scenarios_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"Scenarios directory not found: {root}")

    loaded: dict[str, ScenarioConfig] = {}
    for candidate in sorted(root.glob("*.yaml")):
        parsed = load_scenario(candidate)
        key = parsed.scenario_id
        if key in loaded:
            raise ValueError(
                f"Duplicate scenario_id '{key}' "
                f"detected in {candidate}"
            )
        loaded[key] = parsed

    absent = tuple(wanted for wanted in required if wanted not in loaded)
    if not absent:
        return loaded
    raise ValueError(
        f"Missing required scenario YAML(s): {absent}. "
        f"Found: {tuple(loaded)}"
    )
# Public names of this module; the underscore-prefixed loader helpers are
# intentionally left out.
__all__ = [
"DEFAULT_BASE_RANDOM_SEED",
"DEFAULT_REPLICATIONS",
"DEFAULT_SHIFT_LENGTH_HOURS",
"DEFAULT_TRAVEL_NOISE_CV",
"DEFAULT_TRUCK_COUNT",
"DEFAULT_WARMUP_MINUTES",
"DispatchingParams",
"DumpPointOverride",
"EdgeOverride",
"FleetParams",
"NodeOverride",
"REQUIRED_SCENARIO_IDS",
"RoutingParams",
"ScenarioConfig",
"SimulationParams",
"StochasticityParams",
"load_all_scenarios",
"load_scenario",
"scenario_from_mapping",
]
# Source path: src/mine_sim/scenarios.py
# (Web-page navigation text captured with this file: "Back to submission · View raw on GitHub")