src/mine_sim/viz.py

← Back to submission · View raw on GitHub

"""Visualisation: topology.png and animation.gif from model data.

Both renderers read only the topology CSVs and the simulation event log — no
figures are hand-fabricated. ``render_topology`` draws the static graph with
capacity-1 (single-lane) edges highlighted; ``render_animation`` interpolates
truck positions from one replication's event log.

matplotlib is imported with the non-interactive Agg backend so rendering works
headless. Pillow (a transitive matplotlib dependency) provides the GIF writer.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # headless backend; set before pyplot import

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
from matplotlib.animation import FuncAnimation, PillowWriter  # noqa: E402
from matplotlib.lines import Line2D  # noqa: E402

# ---------------------------------------------------------------------------
# Shared style
# ---------------------------------------------------------------------------
# Marker style per ``node_type`` value (the ``node_type`` column of nodes.csv),
# shared by the static topology diagram and the animation background.
#   color  -> scatter face colour
#   marker -> matplotlib marker code
#   size   -> passed as the scatter ``s`` argument
#   label  -> legend entry text
# Node types missing from this mapping are not scattered, but their id labels
# are still drawn (the label loops iterate over every node row).
NODE_STYLE: dict[str, dict] = {
    "parking": {"color": "#1f77b4", "marker": "s", "size": 360, "label": "Parking"},
    "junction": {"color": "#7f7f7f", "marker": "o", "size": 200, "label": "Junction"},
    "load_ore": {"color": "#2ca02c", "marker": "^", "size": 420, "label": "Ore Loader"},
    "crusher": {"color": "#d62728", "marker": "*", "size": 620, "label": "Primary Crusher"},
    "waste_dump": {"color": "#8c564b", "marker": "X", "size": 320, "label": "Waste Dump"},
    "maintenance": {"color": "#9467bd", "marker": "P", "size": 320, "label": "Maintenance"},
}

CAP1_COLOR = "#e6194b"  # bold red for capacity-1 (bottleneck) edges
CAP_HI_COLOR = "#bfbfbf"  # light grey for high-capacity edges
# Animation truck marker colours, chosen per frame from each truck's load flag.
TRUCK_LOADED_COLOR = "#ff7f0e"
TRUCK_EMPTY_COLOR = "#1f77b4"


def _load_topology_frames(data_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
    """Load the node and edge tables from *data_dir*.

    Rows of ``nodes.csv`` without a ``node_id`` and rows of ``edges.csv``
    without an ``edge_id`` are discarded and the index is renumbered from 0.
    The edge ``capacity`` column is cast to ``int``.

    Returns:
        ``(nodes, edges, coords)`` where ``coords`` maps each ``node_id`` to a
        ``(x_m, y_m)`` tuple of floats.
    """

    def _read_clean(file_name: str, key_column: str) -> pd.DataFrame:
        # Drop rows whose identifying column is blank, then renumber.
        table = pd.read_csv(data_dir / file_name)
        return table.dropna(subset=[key_column]).reset_index(drop=True)

    node_table = _read_clean("nodes.csv", "node_id")
    edge_table = _read_clean("edges.csv", "edge_id")
    edge_table["capacity"] = edge_table["capacity"].astype(int)

    xy_pairs = zip(node_table["x_m"].astype(float), node_table["y_m"].astype(float))
    lookup = dict(zip(node_table["node_id"], xy_pairs))
    return node_table, edge_table, lookup


# ---------------------------------------------------------------------------
# Static topology diagram
# ---------------------------------------------------------------------------
def render_topology(
    data_dir: str | Path,
    out_path: str | Path,
    *,
    footnote: str | None = (
        "WASTE & MAINT are excluded from active ore haulage (no truck traffic)."
    ),
) -> Path:
    """Render topology.png from nodes.csv + edges.csv. Returns the path.

    Edges with ``capacity <= 1`` are drawn as bold red arrows labelled with
    their ``edge_id``. Each arrow is shifted sideways so the two directions of
    a segment stay visible. All other edges are drawn as thin grey arrows.
    Edges whose endpoints are not in nodes.csv are skipped.

    Args:
        data_dir: Directory containing ``nodes.csv`` and ``edges.csv``.
        out_path: Destination image file. Parent directories are created.
        footnote: Italic caption printed bottom-right, beneath the axes.
            ``None`` or ``""`` omits it. The default text is a hand-written
            note; it is not derived from the CSVs.

    Returns:
        ``out_path`` as a :class:`~pathlib.Path`.
    """
    data_dir = Path(data_dir)
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    nodes, edges, coords = _load_topology_frames(data_dir)

    cap1_edges = edges[edges["capacity"] <= 1]
    hi_cap_edges = edges[edges["capacity"] > 1]

    fig, ax = plt.subplots(figsize=(13, 10), dpi=150)
    try:  # always release the figure, even if drawing or saving fails
        ax.set_facecolor("#fafafa")

        # High-capacity edges: thin grey arrows underneath everything else.
        for row in hi_cap_edges.itertuples():
            if row.from_node not in coords or row.to_node not in coords:
                continue
            x0, y0 = coords[row.from_node]
            x1, y1 = coords[row.to_node]
            ax.annotate(
                "", xy=(x1, y1), xytext=(x0, y0),
                arrowprops=dict(arrowstyle="->", color=CAP_HI_COLOR, lw=1.0,
                                shrinkA=12, shrinkB=12, alpha=0.7),
                zorder=1,
            )

        # Capacity-1 edges. Each arrow is shifted perpendicular to its own
        # direction (the unit direction rotated +90 degrees). A->B and B->A
        # therefore land on opposite sides of the segment and both stay visible.
        offset_m = 55.0
        for row in cap1_edges.itertuples():
            if row.from_node not in coords or row.to_node not in coords:
                continue
            x0, y0 = coords[row.from_node]
            x1, y1 = coords[row.to_node]
            dx, dy = x1 - x0, y1 - y0
            length = math.hypot(dx, dy) or 1.0  # guard zero-length self-loops
            ox, oy = -dy / length * offset_m, dx / length * offset_m
            ax.annotate(
                "", xy=(x1 + ox, y1 + oy), xytext=(x0 + ox, y0 + oy),
                arrowprops=dict(arrowstyle="->", color=CAP1_COLOR, lw=2.4,
                                shrinkA=14, shrinkB=14),
                zorder=3,
            )
            ax.text(
                (x0 + x1) / 2.0 + ox, (y0 + y1) / 2.0 + oy, row.edge_id,
                fontsize=7, color=CAP1_COLOR, weight="bold", ha="center", va="center",
                bbox=dict(boxstyle="round,pad=0.18", fc="white", ec=CAP1_COLOR,
                          lw=0.6, alpha=0.9),
                zorder=4,
            )

        for ntype, style in NODE_STYLE.items():
            subset = nodes[nodes["node_type"] == ntype]
            if subset.empty:
                continue
            ax.scatter(subset["x_m"], subset["y_m"], c=style["color"], marker=style["marker"],
                       s=style["size"], edgecolors="black", linewidths=0.8, zorder=5)

        for row in nodes.itertuples():
            ax.annotate(row.node_id, xy=(row.x_m, row.y_m), xytext=(8, 8),
                        textcoords="offset points", fontsize=9, weight="bold", zorder=6)

        node_handles = [
            Line2D([0], [0], marker=s["marker"], color="w", markerfacecolor=s["color"],
                   markeredgecolor="black", markersize=10, label=s["label"])
            for s in NODE_STYLE.values()
        ]
        edge_handles = [
            Line2D([0], [0], color=CAP1_COLOR, lw=2.4, label="Capacity-1 edge (SimPy Resource)"),
            Line2D([0], [0], color=CAP_HI_COLOR, lw=1.0, label="High-capacity edge"),
        ]
        # A second ax.legend() call replaces the first, so keep the node legend
        # as an explicit artist.
        leg1 = ax.legend(handles=node_handles, title="Nodes", loc="upper left",
                         fontsize=9, title_fontsize=10, framealpha=0.95)
        ax.add_artist(leg1)
        ax.legend(handles=edge_handles, title="Edges", loc="lower right",
                  fontsize=9, title_fontsize=10, framealpha=0.95)

        ax.set_title(
            f"Synthetic Mine Topology (benchmark 001)\n"
            f"{len(nodes)} nodes, {len(edges)} directed edges, "
            f"{len(cap1_edges)} capacity-1 segments highlighted",
            fontsize=13, weight="bold",
        )
        ax.set_xlabel("x (m)")
        ax.set_ylabel("y (m)")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.set_aspect("equal", adjustable="datalim")

        if footnote:
            # The caption used to sit inside the axes' lower-right corner, where
            # the opaque "Edges" legend covered it. Reserve a strip at the bottom
            # of the figure and draw the caption there instead.
            fig.tight_layout(rect=(0.0, 0.03, 1.0, 1.0))
            fig.text(0.99, 0.01, footnote, fontsize=8, color="#444",
                     ha="right", va="bottom", style="italic")
        else:
            fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
    return out_path


# ---------------------------------------------------------------------------
# Animation
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class _TruckTimeline:
    """Piecewise-linear position track for one truck, built from its logged events.

    Index ``i`` is one logged event: the truck was at ``(xs[i], ys[i])`` at
    ``times[i]``, with load flag ``loaded[i]``. All arrays have the same length.
    ``times`` must be non-decreasing (``_build_timelines`` sorts by ``time_min``).
    """

    truck_id: str
    times: np.ndarray  # event times from the ``time_min`` column, non-decreasing
    xs: np.ndarray  # x coordinate of each event's location node
    ys: np.ndarray  # y coordinate of each event's location node
    loaded: np.ndarray  # boolean load flag recorded with each event

    def position_at(self, t: float) -> tuple[float, float, bool]:
        """Return ``(x, y, loaded)`` for the truck at time *t*.

        Before the first event or after the last one, the truck is pinned to
        that event's location and state. In between, the position is linearly
        interpolated between the two events that bracket *t*.

        The load flag is the one recorded at the most recent event at or before
        *t*: the truck keeps its last logged state until an event changes it.
        Previously the *next* event's flag was used, so a truck turned empty
        as soon as it started dumping, while the dump counter only advances at
        ``end_dump``. NOTE(review): this assumes ``loaded`` holds the state
        after each event. Confirm against the simulator's logging.
        """
        times = self.times
        if t <= times[0]:
            return float(self.xs[0]), float(self.ys[0]), bool(self.loaded[0])
        if t >= times[-1]:
            return float(self.xs[-1]), float(self.ys[-1]), bool(self.loaded[-1])
        # Find the last event with time <= t. Since times[0] < t < times[-1],
        # times[idx] <= t < times[idx + 1], so the segment has positive
        # duration even when several events share a timestamp. The clamp only
        # guards against malformed (unsorted) input.
        idx = int(np.searchsorted(times, t, side="right")) - 1
        idx = max(0, min(idx, len(times) - 2))
        t0, t1 = times[idx], times[idx + 1]
        frac = (t - t0) / (t1 - t0)
        x = self.xs[idx] + frac * (self.xs[idx + 1] - self.xs[idx])
        y = self.ys[idx] + frac * (self.ys[idx + 1] - self.ys[idx])
        return float(x), float(y), bool(self.loaded[idx])


def _event_slice(event_log_path: Path, scenario_id: str, replication: int) -> pd.DataFrame:
    """Return the event-log rows for one (scenario, replication) pair.

    Rows are stably sorted by ``time_min``, so events with equal times keep
    their file order, and the index is renumbered from 0.

    Raises:
        ValueError: if no row matches. The message lists the scenario ids
            present in the file.
    """
    log = pd.read_csv(event_log_path)
    wanted = (log["scenario_id"] == scenario_id) & (log["replication"] == replication)
    selected = log.loc[wanted]
    if selected.empty:
        known = sorted(log["scenario_id"].unique())
        raise ValueError(
            f"No events for scenario_id={scenario_id!r} replication={replication} in "
            f"{event_log_path}. Available scenarios: {known}"
        )
    ordered = selected.sort_values("time_min", kind="stable")
    return ordered.reset_index(drop=True)


def _build_timelines(events: pd.DataFrame, coords: dict) -> list[_TruckTimeline]:
    """Convert an event-log slice into one ``_TruckTimeline`` per truck.

    Events whose ``location`` is not a key of *coords* are discarded. The rest
    are grouped by ``truck_id``, with trucks in sorted order. Each group is
    stably sorted by ``time_min`` and every event is mapped to its node's
    coordinates. An event counts as loaded when ``str(loaded).lower()`` equals
    ``"true"``, so both boolean and textual "True"/"true" values work.
    """
    on_graph = events.loc[events["location"].isin(coords.keys())].copy()
    result: list[_TruckTimeline] = []
    for tid, rows in on_graph.groupby("truck_id", sort=True):
        ordered = rows.sort_values("time_min", kind="stable")
        points = [coords[name] for name in ordered["location"]]
        result.append(
            _TruckTimeline(
                str(tid),
                ordered["time_min"].to_numpy(dtype=float),
                np.array([pt[0] for pt in points], dtype=float),
                np.array([pt[1] for pt in points], dtype=float),
                ordered["loaded"].astype(str).str.lower().eq("true").to_numpy(),
            )
        )
    return result


def _cumulative_dumps(events: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(times, counts)`` for the ``end_dump`` events in *events*.

    ``times`` holds the dump completion times in ascending order, and
    ``counts[i]`` is the running total (``i + 1``) at ``times[i]``. Both are
    empty when there are no dumps.
    """
    is_dump = events["event_type"] == "end_dump"
    finish_times = np.sort(events.loc[is_dump, "time_min"].to_numpy(dtype=float))
    running_total = np.arange(1, finish_times.size + 1, dtype=int)
    return finish_times, running_total


def _dumps_done_at(t: float, dump_times: np.ndarray, dump_counts: np.ndarray) -> int:
    """Return the running dump count at time *t* (0 before the first dump).

    *dump_times* must be ascending. A dump finishing exactly at *t* is
    included in the count.
    """
    if len(dump_times) == 0:
        return 0
    if t < dump_times[0]:
        return 0
    last_done = int(np.searchsorted(dump_times, t, side="right")) - 1
    last_done = max(0, min(last_done, len(dump_counts) - 1))
    return int(dump_counts[last_done])


def _draw_static(ax, nodes: pd.DataFrame, edges: pd.DataFrame, coords: dict) -> None:
    """Draw the fixed animation background on *ax*: edges, nodes and node ids.

    High-capacity edges (thin grey) are plotted before capacity<=1 edges
    (thicker red), so bottlenecks sit on top. Edges with an endpoint missing
    from *coords* are skipped. Nodes are scattered per ``node_type`` using
    ``NODE_STYLE``, then every node is labelled with its id.
    """
    edge_layers = (
        (edges["capacity"] > 1, {"color": CAP_HI_COLOR, "lw": 1.0, "alpha": 0.6, "zorder": 1}),
        (edges["capacity"] <= 1, {"color": CAP1_COLOR, "lw": 1.6, "alpha": 0.55, "zorder": 2}),
    )
    for selector, line_kw in edge_layers:
        for edge in edges[selector].itertuples():
            start = coords.get(edge.from_node)
            end = coords.get(edge.to_node)
            if start is None or end is None:
                continue
            ax.plot([start[0], end[0]], [start[1], end[1]], **line_kw)

    for node_type, look in NODE_STYLE.items():
        members = nodes.loc[nodes["node_type"] == node_type]
        if members.empty:
            continue
        ax.scatter(members["x_m"], members["y_m"], c=look["color"], marker=look["marker"],
                   s=look["size"], edgecolors="black", linewidths=0.6, zorder=3)

    for node in nodes.itertuples():
        ax.annotate(node.node_id, xy=(node.x_m, node.y_m), xytext=(7, 7),
                    textcoords="offset points", fontsize=7.5, weight="bold",
                    color="#222", zorder=4)


def render_animation(
    event_log_path: str | Path,
    data_dir: str | Path,
    out_path: str | Path,
    *,
    scenario_id: str = "baseline",
    replication: int = 0,
    shift_minutes: float = 480.0,
    n_frames: int = 144,
    fps: int = 12,
    payload_tonnes: float = 100.0,
) -> Path:
    """Render animation.gif for one (scenario, replication) from the event log.

    Frames are sampled at ``n_frames`` evenly spaced times from 0 to
    ``shift_minutes`` inclusive. Each frame shows every truck's interpolated
    position, coloured by load state, plus a HUD with the elapsed time, the
    number of completed dumps and ``dumps * payload_tonnes``.

    Args:
        event_log_path: CSV event log covering one or more scenarios.
        data_dir: Directory containing ``nodes.csv`` and ``edges.csv``.
        out_path: Destination GIF file. Parent directories are created.
        scenario_id: Scenario to animate.
        replication: Replication number to animate.
        shift_minutes: Time of the last frame.
        n_frames: Number of frames; must be >= 1.
        fps: GIF frame rate; must be > 0.
        payload_tonnes: Tonnes credited per completed dump in the HUD.

    Returns:
        ``out_path`` as a :class:`~pathlib.Path`.

    Raises:
        ValueError: if ``n_frames`` < 1 or ``fps`` <= 0, or if the log has no
            events for the requested scenario/replication.
        RuntimeError: if no truck timeline could be built from the events.
    """
    # Fail fast. Previously these only failed inside the GIF writer after all
    # frames had been rendered (IndexError / ZeroDivisionError).
    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1, got {n_frames}")
    if fps <= 0:
        raise ValueError(f"fps must be > 0, got {fps}")

    event_log_path = Path(event_log_path)
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    nodes, edges, coords = _load_topology_frames(Path(data_dir))
    events = _event_slice(event_log_path, scenario_id, replication)
    timelines = _build_timelines(events, coords)
    if not timelines:
        raise RuntimeError("No truck timelines could be built from the event log slice.")
    dump_times, dump_counts = _cumulative_dumps(events)

    fig, ax = plt.subplots(figsize=(11, 8.5), dpi=110)
    try:  # always release the figure, even if rendering or saving fails
        ax.set_facecolor("#fafafa")
        _draw_static(ax, nodes, edges, coords)

        empty_scatter = ax.scatter([], [], c=TRUCK_EMPTY_COLOR, marker="o", s=90,
                                   edgecolors="black", linewidths=0.6, zorder=6)
        loaded_scatter = ax.scatter([], [], c=TRUCK_LOADED_COLOR, marker="o", s=90,
                                    edgecolors="black", linewidths=0.6, zorder=6)
        hud = ax.text(0.01, 0.99, "", transform=ax.transAxes, fontsize=11, weight="bold",
                      ha="left", va="top",
                      bbox=dict(boxstyle="round,pad=0.35", fc="white", ec="#444",
                                lw=0.8, alpha=0.9),
                      zorder=10)

        ax.set_title(
            f"Mine haulage animation — scenario={scenario_id}, replication={replication}\n"
            "empty (blue) vs loaded (orange) trucks; capacity-1 edges in red",
            fontsize=12, weight="bold",
        )
        ax.set_xlabel("x (m)")
        ax.set_ylabel("y (m)")
        ax.grid(True, linestyle=":", alpha=0.35)
        ax.set_aspect("equal", adjustable="datalim")

        node_handles = [
            Line2D([0], [0], marker=s["marker"], color="w", markerfacecolor=s["color"],
                   markeredgecolor="black", markersize=9, label=s["label"])
            for s in NODE_STYLE.values()
        ]
        truck_handles = [
            Line2D([0], [0], marker="o", color="w", markerfacecolor=TRUCK_EMPTY_COLOR,
                   markeredgecolor="black", markersize=9, label="Empty truck"),
            Line2D([0], [0], marker="o", color="w", markerfacecolor=TRUCK_LOADED_COLOR,
                   markeredgecolor="black", markersize=9, label="Loaded truck"),
        ]
        # A second ax.legend() call replaces the first, so keep the node legend
        # as an explicit artist.
        leg1 = ax.legend(handles=node_handles, title="Nodes", loc="upper right",
                         fontsize=8, title_fontsize=9, framealpha=0.95)
        ax.add_artist(leg1)
        ax.legend(handles=truck_handles, title="Trucks", loc="lower right",
                  fontsize=8, title_fontsize=9, framealpha=0.95)

        frame_times = np.linspace(0.0, shift_minutes, n_frames)
        no_points = np.empty((0, 2))

        def init():
            empty_scatter.set_offsets(no_points)
            loaded_scatter.set_offsets(no_points)
            hud.set_text("")
            return [empty_scatter, loaded_scatter, hud]

        def update(frame_idx: int):
            t = float(frame_times[frame_idx])
            empty_xy, loaded_xy = [], []
            for tl in timelines:
                x, y, loaded = tl.position_at(t)
                (loaded_xy if loaded else empty_xy).append((x, y))
            empty_scatter.set_offsets(np.array(empty_xy) if empty_xy else no_points)
            loaded_scatter.set_offsets(np.array(loaded_xy) if loaded_xy else no_points)
            dumps = _dumps_done_at(t, dump_times, dump_counts)
            hud.set_text(
                f"t = {t:6.1f} min  ({t / 60.0:4.2f} h)\n"
                f"completed dumps: {dumps:>3d}\n"
                f"cumulative tonnes: {dumps * payload_tonnes:>6.0f} t"
            )
            return [empty_scatter, loaded_scatter, hud]

        anim = FuncAnimation(fig, update, frames=n_frames, init_func=init,
                             interval=1000.0 / fps, blit=False, repeat=False)
        anim.save(str(out_path), writer=PillowWriter(fps=fps), dpi=110)
    finally:
        plt.close(fig)
    return out_path


__all__ = ["render_animation", "render_topology"]