src/mine_sim/routing.py

← Back to submission · View raw on GitHub

"""Static shortest-time routing, reachability self-check, and dispatch cost.

Routing is *static shortest-time per scenario*: one Dijkstra pass on
free-flow edge times (``distance_m / max_speed_kph``), recomputed whenever a
scenario closes or upgrades edges. Routes are frozen at construction time;
the simulation reads precomputed paths during travel and never re-runs
Dijkstra on the hot path.

Reachability is enforced loudly: if any required OD pair
(PARK<->LOAD_N, PARK<->LOAD_S, LOAD_N<->CRUSH, LOAD_S<->CRUSH) is
unreachable in a scenario, :func:`assert_reachable` raises rather than let
the simulation silently produce misleading results.

The *dispatch* (loader-choice) cost lives here too, as a pure function, so
the policy can be unit-tested without spinning up SimPy. The route a truck
follows is static; the *loader it chooses* is dynamic (depends on live
queue lengths) — these two concerns are kept distinct.
"""

from __future__ import annotations

from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Sequence

import networkx as nx

from mine_sim.topology import EdgeView, Topology

# ---------------------------------------------------------------------------
# Required OD pairs that must always be reachable (single source of truth,
# referenced by both the simulation and the tests).
# ---------------------------------------------------------------------------
# Each undirected leg below is expanded into its forward pair immediately
# followed by its reverse pair, giving eight directed OD pairs in total.
REQUIRED_OD_PAIRS: tuple[tuple[str, str], ...] = tuple(
    directed
    for end_a, end_b in (
        ("PARK", "LOAD_N"),
        ("PARK", "LOAD_S"),
        ("LOAD_N", "CRUSH"),
        ("LOAD_S", "CRUSH"),
    )
    for directed in ((end_a, end_b), (end_b, end_a))
)

#: The four cycle anchors between which routes are precomputed. Used by
#: :func:`compute_routes` as the default set of both origins and destinations.
CYCLE_NODES: tuple[str, ...] = ("PARK", "LOAD_N", "LOAD_S", "CRUSH")


# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Route:
    """Immutable record of one precomputed origin-to-destination path.

    ``edge_ids`` lists the identifiers of the edges along the path, in
    travel order; ``free_flow_time_min`` is the path's total free-flow time.
    """

    origin: str
    destination: str
    edge_ids: tuple[str, ...]
    free_flow_time_min: float

    @property
    def is_trivial(self) -> bool:
        """Whether the route starts and ends at the same node."""
        return self.destination == self.origin


@dataclass(frozen=True)
class RoutingTable:
    """All precomputed routes for one scenario, keyed by (origin, destination).

    The mapping passed to the constructor is copied and exposed through a
    read-only view, so a table cannot be altered after construction — neither
    through ``table.routes`` nor by mutating the dict the caller passed in.
    """

    routes: Mapping[tuple[str, str], Route]

    def __post_init__(self) -> None:
        # ``frozen=True`` only prevents rebinding ``self.routes``; it does not
        # stop mutation of the mapping object itself. Snapshot it into a
        # read-only proxy. ``object.__setattr__`` is required because normal
        # attribute assignment is blocked on frozen dataclasses.
        object.__setattr__(self, "routes", MappingProxyType(dict(self.routes)))

    def get(self, origin: str, destination: str) -> Route | None:
        """Return the route for ``(origin, destination)``, or ``None`` if absent."""
        return self.routes.get((origin, destination))

    def require(self, origin: str, destination: str) -> Route:
        """Return the route for ``(origin, destination)``.

        Raises:
            KeyError: if no route was precomputed for the pair.
        """
        route = self.get(origin, destination)
        if route is None:
            raise KeyError(
                f"No precomputed route from {origin} to {destination}. "
                "Was the reachability self-check run?"
            )
        return route


class ReachabilityError(RuntimeError):
    """Raised loudly when a required OD pair is unreachable in a scenario.

    Raised by :func:`assert_reachable`; the message lists every missing
    ``origin -> destination`` pair.
    """


# ---------------------------------------------------------------------------
# Dispatch (loader-choice) policy — pure functions, no SimPy.
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class LoaderCandidate:
    """Snapshot of one loader option as seen by a truck at dispatch time.

    ``travel_time_min`` is the (speed-scaled) free-flow time from the truck's
    current node to the loader. ``queue_len`` is the live number of trucks at
    the loader, counting both the one in service and those waiting.
    ``mean_load_time_min`` is the loader's mean time per load.
    """

    loader_id: str
    travel_time_min: float
    queue_len: int
    mean_load_time_min: float

    @property
    def cost(self) -> float:
        """Dispatch cost of this option, as given by :func:`loader_choice_cost`."""
        return loader_choice_cost(
            travel_time_min=self.travel_time_min,
            queue_len=self.queue_len,
            mean_load_time_min=self.mean_load_time_min,
        )


def loader_choice_cost(
    travel_time_min: float,
    queue_len: int,
    mean_load_time_min: float,
) -> float:
    """Estimate the minutes until a truck would finish loading at a loader.

    The estimate is the sum of three terms::

        travel time to the loader
        + queue_len * mean_load_time    (loads already ahead of this truck)
        + mean_load_time                (this truck's own load)

    ``queue_len`` is expected to count the truck currently being served as
    well as those waiting, and is charged one full mean load time per truck.
    The arguments are coerced with :func:`float` / :func:`int` before use.
    """
    travel = float(travel_time_min)
    trucks_ahead = int(queue_len)
    service = float(mean_load_time_min)
    # Keep the left-to-right addition order: travel, then waiting, then own load.
    return travel + trucks_ahead * service + service


def select_loader(candidates: Sequence[LoaderCandidate]) -> str:
    """Pick the cheapest candidate and return its ``loader_id``.

    Candidates are first arranged by ascending ``loader_id``; :func:`min`
    keeps the first of several equal-cost entries, so a tie always resolves
    to the lowest id regardless of the order the caller supplied.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("select_loader requires at least one candidate")

    by_id = list(candidates)
    by_id.sort(key=lambda candidate: candidate.loader_id)
    cheapest = min(by_id, key=lambda candidate: candidate.cost)
    return cheapest.loader_id


# ---------------------------------------------------------------------------
# Graph construction
# ---------------------------------------------------------------------------
def _build_directed_graph(topology: Topology) -> nx.DiGraph:
    """Assemble the routing graph, with free-flow minutes as edge weights.

    Every node in ``topology.nodes`` is added, even if no usable edge touches
    it. Edges flagged ``closed`` and edges whose free-flow time is infinite
    are left out, so Dijkstra can never route over them. Should two edges
    connect the same ``(from_node, to_node)`` pair, only the strictly cheaper
    one is kept (on a tie the one seen first wins).
    """
    digraph: nx.DiGraph = nx.DiGraph()
    digraph.add_nodes_from(topology.nodes)

    for edge in topology.edges.values():
        if edge.closed:
            continue
        minutes = edge.free_flow_time_min
        if minutes == float("inf"):
            continue

        tail, head = edge.from_node, edge.to_node
        incumbent = digraph.get_edge_data(tail, head)
        keep_incumbent = incumbent is not None and incumbent["weight"] <= minutes
        if not keep_incumbent:
            # add_edge on an existing (tail, head) overwrites its attributes.
            digraph.add_edge(tail, head, weight=minutes, edge_id=edge.edge_id)

    return digraph


def _path_to_edge_ids(graph: nx.DiGraph, path: list[str]) -> tuple[str, ...]:
    """Translate a node path into the ids of the edges it traverses.

    Consecutive nodes of ``path`` are paired up and the ``edge_id`` attribute
    stored at ``graph[u][v]`` is collected for each pair, in path order. A
    path with fewer than two nodes yields an empty tuple.
    """
    edge_ids = []
    for tail, head in zip(path, path[1:]):
        edge_ids.append(graph[tail][head]["edge_id"])
    return tuple(edge_ids)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def compute_routes(
    topology: Topology,
    sources: tuple[str, ...] | None = None,
    targets: tuple[str, ...] | None = None,
) -> RoutingTable:
    """Build the table of shortest-time routes for one topology.

    One single-source Dijkstra pass is run per origin on the free-flow-time
    graph. When ``sources`` / ``targets`` are omitted, both default to
    ``CYCLE_NODES``, i.e. every ordered pair of the four cycle anchors.

    Origins that are not nodes of the graph, and destinations Dijkstra did
    not reach from an origin, are simply absent from the result; nothing is
    raised here. The returned table wraps its routes in a read-only mapping.
    """
    origins = CYCLE_NODES if sources is None else sources
    destinations = CYCLE_NODES if targets is None else targets

    graph = _build_directed_graph(topology)
    found: dict[tuple[str, str], Route] = {}

    for start in origins:
        if start not in graph:
            continue
        minutes_to, path_to = nx.single_source_dijkstra(
            graph, start, weight="weight"
        )
        found.update(
            {
                (start, end): Route(
                    origin=start,
                    destination=end,
                    edge_ids=_path_to_edge_ids(graph, path_to[end]),
                    free_flow_time_min=float(minutes_to[end]),
                )
                for end in destinations
                if end in path_to
            }
        )

    return RoutingTable(routes=MappingProxyType(found))


def assert_reachable(
    table: RoutingTable,
    required: tuple[tuple[str, str], ...] = REQUIRED_OD_PAIRS,
    scenario_id: str | None = None,
) -> None:
    """Verify that ``table`` holds a route for every pair in ``required``.

    Returns ``None`` when nothing is missing. Otherwise the error message
    names *all* missing pairs at once, prefixed with the scenario id when a
    non-empty one is given.

    Raises:
        ReachabilityError: if at least one required pair has no route.
    """
    absent = [od for od in required if od not in table.routes]
    if absent:
        listing = ", ".join(
            f"{origin} -> {destination}" for origin, destination in absent
        )
        header = f"Scenario '{scenario_id}': " if scenario_id else ""
        raise ReachabilityError(
            f"{header}required OD pairs unreachable in current topology: {listing}"
        )


def free_flow_edge_time_min(edge: EdgeView) -> float:
    """Return the free-flow traversal time stored on ``edge``.

    Thin function-style accessor for :attr:`EdgeView.free_flow_time_min`.
    """
    minutes = edge.free_flow_time_min
    return minutes


# Names exported by a star-import of this module. The underscore-prefixed
# graph helpers are not listed.
__all__ = [
    "CYCLE_NODES",
    "LoaderCandidate",
    "REQUIRED_OD_PAIRS",
    "ReachabilityError",
    "Route",
    "RoutingTable",
    "assert_reachable",
    "compute_routes",
    "free_flow_edge_time_min",
    "loader_choice_cost",
    "select_loader",
]