# src/mine_sim/cli.py

# ← Back to submission · View raw on GitHub

"""Command-line interface for the mine throughput simulation.

Four subcommands::

    python -m mine_sim run <scenario_id>   # one scenario (smoke testing)
    python -m mine_sim run-all             # every required scenario (the deliverable)
    python -m mine_sim list                # list available scenarios
    python -m mine_sim render              # topology.png + animation.gif

``run`` and ``run-all`` write the three flat artefacts — ``results.csv``,
``event_log.csv``, ``summary.json`` — directly into ``--output-dir`` (default
the submission root, ``.``). All replications feed the metrics and CIs; the
event log defaults to ``--event-log-scope first`` (replication 0 of each
scenario only) so the trace stays small and inspectable.

The CLI is side-effect-only at the boundary: every decision lives in the
orchestration / aggregation / io modules, so the same calls can be driven
from a notebook or test.
"""

from __future__ import annotations

import argparse
import sys
import time
from dataclasses import replace
from pathlib import Path
from typing import Sequence

from mine_sim.aggregate import RunSummary, aggregate_run
from mine_sim.io_writers import (
    collect_events,
    write_event_log_csv,
    write_results_csv,
    write_run_summary_json,
)
from mine_sim.narrative import DEFAULT_BENCHMARK_ID
from mine_sim.scenario_runner import (
    MultiScenarioRunResult,
    ReplicationProgress,
    run_all_scenarios,
    run_scenario,
)
from mine_sim.scenarios import (
    REQUIRED_SCENARIO_IDS,
    ScenarioConfig,
    load_all_scenarios,
    load_scenario,
)

#: Default directory holding the input CSVs (``--data-dir``).
DEFAULT_DATA_DIR = Path("data")
#: Default directory holding the scenario YAMLs: ``scenarios`` under the data dir.
DEFAULT_SCENARIOS_DIR = DEFAULT_DATA_DIR / "scenarios"
#: Default output directory: the submission root, so the five deliverables
#: sit co-located at the top level where the harness looks for them.
DEFAULT_OUTPUT_DIR = Path(".")
#: Shift length (hours) passed as ``expected_shift_length_hours`` when
#: summary.json is written.
DEFAULT_SHIFT_LENGTH_HOURS = 8.0


# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
def build_parser() -> argparse.ArgumentParser:
    """Assemble the top-level parser together with its four subcommands.

    A subcommand is mandatory; its name is stored on the parsed namespace
    under ``command``.
    """
    root = argparse.ArgumentParser(
        prog="python -m mine_sim",
        description=(
            "SimPy mine throughput simulation. 'run-all' produces the "
            "submission deliverables; 'run' is for single-scenario smoke tests."
        ),
    )
    commands = root.add_subparsers(dest="command", required=True)
    # Register the subcommands in a fixed order: run, run-all, list, render.
    for register in (
        _add_run_parser,
        _add_run_all_parser,
        _add_list_parser,
        _add_render_parser,
    ):
        register(commands)
    return root


def _add_common_run_args(sub: argparse.ArgumentParser) -> None:
    """Attach the options shared by ``run`` and ``run-all`` to ``sub``.

    Options are declared as (flag, keyword-arguments) pairs and registered in
    the order listed.
    """
    shared_options = (
        ("--data-dir", dict(
            type=Path,
            default=DEFAULT_DATA_DIR,
            help="Directory with input CSVs. Default: ./data",
        )),
        ("--scenarios-dir", dict(
            type=Path,
            default=DEFAULT_SCENARIOS_DIR,
            help="Directory with scenario YAMLs. Default: ./data/scenarios",
        )),
        ("--output-dir", dict(
            type=Path,
            default=DEFAULT_OUTPUT_DIR,
            help=(
                "Where to write results.csv/event_log.csv/summary.json. "
                "Default: . (submission root)"
            ),
        )),
        ("--reps", dict(
            type=int,
            default=None,
            help="Override replication count (smoke tests). Default: 30 from YAML.",
        )),
        ("--event-log-scope", dict(
            choices=("first", "all"),
            default="first",
            help=(
                "Which replications to include in event_log.csv: 'first' "
                "(rep 0 of each scenario, default) or 'all'."
            ),
        )),
        ("--quiet", dict(
            action="store_true",
            help="Suppress per-replication progress output.",
        )),
    )
    for flag, spec in shared_options:
        sub.add_argument(flag, **spec)


def _add_run_parser(subparsers) -> None:
    """Register the ``run`` subcommand.

    It takes one positional ``scenario_id`` followed by the options shared
    with ``run-all``.
    """
    run_cmd = subparsers.add_parser(
        "run",
        help="Run a single scenario (30 reps by default).",
        description="Run one scenario and write the three flat artefacts to --output-dir.",
    )
    run_cmd.add_argument("scenario_id", help="Scenario ID (e.g. 'baseline').")
    _add_common_run_args(run_cmd)


def _add_run_all_parser(subparsers) -> None:
    """Register the ``run-all`` subcommand.

    It takes an optional ``--scenario-ids`` subset followed by the options
    shared with ``run``.
    """
    run_all_cmd = subparsers.add_parser(
        "run-all",
        help="Run every required scenario (30 reps each by default).",
        description="Run all required scenarios and write the flat deliverables to --output-dir.",
    )
    run_all_cmd.add_argument(
        "--scenario-ids",
        type=str,
        default=None,
        help="Optional comma-separated subset (default: all required scenarios).",
    )
    _add_common_run_args(run_all_cmd)


def _add_list_parser(subparsers) -> None:
    """Register the ``list`` subcommand; its only option is ``--scenarios-dir``."""
    list_cmd = subparsers.add_parser("list", help="List available scenario IDs.")
    list_cmd.add_argument(
        "--scenarios-dir",
        type=Path,
        default=DEFAULT_SCENARIOS_DIR,
        help="Directory with scenario YAMLs. Default: ./data/scenarios",
    )


def _add_render_parser(subparsers) -> None:
    """Register the ``render`` subcommand and its options.

    Options are declared as (flag, keyword-arguments) pairs and registered in
    the order listed: input locations, which scenario/replication to animate,
    output files, animation parameters, then the two skip switches.
    """
    render_cmd = subparsers.add_parser(
        "render",
        help="Render topology.png and animation.gif.",
        description="Render the static topology and an animation from event_log.csv.",
    )
    render_options = (
        ("--data-dir", {"type": Path, "default": DEFAULT_DATA_DIR}),
        ("--event-log", {
            "type": Path,
            "default": DEFAULT_OUTPUT_DIR / "event_log.csv",
            "help": "event_log.csv to animate. Default: ./event_log.csv",
        }),
        ("--scenario", {"type": str, "default": "baseline"}),
        ("--replication", {"type": int, "default": 0}),
        ("--topology-out", {"type": Path, "default": DEFAULT_OUTPUT_DIR / "topology.png"}),
        ("--animation-out", {"type": Path, "default": DEFAULT_OUTPUT_DIR / "animation.gif"}),
        ("--frames", {"type": int, "default": 144}),
        ("--fps", {"type": int, "default": 12}),
        ("--shift-minutes", {"type": float, "default": 480.0}),
        ("--skip-topology", {"action": "store_true"}),
        ("--skip-animation", {"action": "store_true"}),
    )
    for flag, spec in render_options:
        render_cmd.add_argument(flag, **spec)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _parse_scenario_ids(raw: str | None) -> tuple[str, ...] | None:
    if raw is None or raw.strip() == "":
        return None
    ids = tuple(token.strip() for token in raw.split(",") if token.strip())
    return ids or None


def _override_reps(scenario: ScenarioConfig, reps: int | None) -> ScenarioConfig:
    """Return ``scenario`` with its replication count replaced by ``reps``.

    ``None`` means "no override": the scenario is returned as-is. Otherwise a
    copy is built with ``dataclasses.replace``; the input is not mutated.

    Raises:
        SystemExit: If ``reps`` is zero or negative.
    """
    if reps is not None:
        if reps <= 0:
            raise SystemExit(f"--reps must be > 0 (got {reps}).")
        patched_simulation = replace(scenario.simulation, replications=reps)
        return replace(scenario, simulation=patched_simulation)
    return scenario


def _make_progress_printer(quiet: bool):
    """Build the per-replication progress callback, or ``None`` when quiet.

    The returned callable writes one line per progress event to stdout and
    flushes it. The elapsed time shown is measured from the moment this
    factory was called.
    """
    if quiet:
        return None

    started_at = time.monotonic()

    def _report(event: ReplicationProgress) -> None:
        seconds = time.monotonic() - started_at
        scenario_tag = (
            f"[{event.scenario_index + 1}/{event.scenario_total}] {event.scenario_id}"
        )
        metrics = event.result.metrics
        # Indices on the event are zero-based; show them one-based.
        rep_tag = f"rep {event.replication_index + 1}/{event.replication_total}"
        figures = (
            f"tonnes={metrics.total_tonnes_delivered:.0f} "
            f"tph={metrics.tonnes_per_hour:.1f}"
        )
        sys.stdout.write(f"  {scenario_tag} {rep_tag} {figures} ({seconds:.1f}s)\n")
        sys.stdout.flush()

    return _report


def _write_flat_outputs(
    multi: MultiScenarioRunResult,
    run_summary: RunSummary,
    output_dir: Path,
    *,
    event_log_scope: str,
    expected_scenario_ids: Sequence[str] | None,
    expected_replications: int | None,
) -> tuple[Path, Path, Path]:
    """Write results.csv, event_log.csv and summary.json directly into ``output_dir``.

    The directory (and any missing parents) is created first. The results
    file gets every replication from ``multi``; the event log gets the events
    selected by ``event_log_scope``. The ``expected_*`` values, together with
    ``DEFAULT_SHIFT_LENGTH_HOURS``, are forwarded to the summary writer.

    Returns:
        The values returned by the three writers, in the order
        (results, event log, summary).
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    replications = list(multi.all_replications())

    results_csv = write_results_csv(replications, output_dir / "results.csv")
    event_log_csv = write_event_log_csv(
        collect_events(replications, scope=event_log_scope),
        output_dir / "event_log.csv",
    )
    summary_json = write_run_summary_json(
        run_summary,
        output_dir / "summary.json",
        expected_scenario_ids=expected_scenario_ids,
        expected_replications=expected_replications,
        expected_shift_length_hours=DEFAULT_SHIFT_LENGTH_HOURS,
    )
    return results_csv, event_log_csv, summary_json


# ---------------------------------------------------------------------------
# Command handlers
# ---------------------------------------------------------------------------
def cmd_run(args: argparse.Namespace) -> int:
    """Handle ``run``: simulate one scenario and write its flat artefacts.

    The scenario is loaded from ``<scenarios_dir>/<scenario_id>.yaml`` and
    ``--reps`` (if given) overrides its replication count.

    Returns:
        2 (after a message on stderr) when the scenario YAML does not exist,
        otherwise 0.
    """
    yaml_path = args.scenarios_dir / f"{args.scenario_id}.yaml"
    if not yaml_path.exists():
        sys.stderr.write(f"Scenario YAML not found: {yaml_path}\n")
        return 2

    scenario = _override_reps(load_scenario(yaml_path), args.reps)
    on_progress = _make_progress_printer(args.quiet)

    banner = (
        f"Running scenario '{scenario.scenario_id}' "
        f"({scenario.simulation.replications} reps) -> {args.output_dir}\n"
    )
    sys.stdout.write(banner)
    sys.stdout.flush()

    began = time.monotonic()
    scenario_result = run_scenario(scenario, args.data_dir, progress=on_progress)
    elapsed = time.monotonic() - began

    # Wrap the single result in the multi-scenario container so the writer
    # shared with run-all can be reused.
    multi = MultiScenarioRunResult(
        results={scenario.scenario_id: scenario_result}  # type: ignore[arg-type]
    )
    run_summary = aggregate_run({scenario.scenario_id: scenario_result.replications})

    # An explicit --reps disables the replication-count expectation.
    reps_expectation = None if args.reps is not None else 30
    results_path, event_log_path, summary_path = _write_flat_outputs(
        multi,
        run_summary,
        args.output_dir,
        event_log_scope=args.event_log_scope,
        expected_scenario_ids=None,
        expected_replications=reps_expectation,
    )

    summary = run_summary.scenarios[scenario.scenario_id]
    report = (
        f"\nScenario '{summary.scenario_id}' complete in {elapsed:.1f}s.\n"
        f"  total_tonnes: {summary.total_tonnes_delivered.mean:.0f} "
        f"[{summary.total_tonnes_delivered.ci95_low:.0f}, "
        f"{summary.total_tonnes_delivered.ci95_high:.0f}]\n"
        f"  tonnes/hour:  {summary.tonnes_per_hour.mean:.1f}\n"
        f"  artefacts: {results_path}, {event_log_path}, {summary_path}\n"
    )
    sys.stdout.write(report)
    return 0


def cmd_run_all(args: argparse.Namespace) -> int:
    """Handle ``run-all``: simulate the requested scenarios and write the artefacts.

    Without ``--scenario-ids`` every ID in ``REQUIRED_SCENARIO_IDS`` is run.
    The scenario-set expectation passed to the summary writer applies only in
    that default case, and the replication-count expectation only when
    ``--reps`` was not given.

    Returns:
        Always 0; failures surface as exceptions from the called modules.
    """
    explicit_ids = _parse_scenario_ids(args.scenario_ids)
    using_default_set = explicit_ids is None
    requested_ids = REQUIRED_SCENARIO_IDS if using_default_set else explicit_ids

    sys.stdout.write(
        f"Running {len(requested_ids)} scenarios from {args.scenarios_dir} "
        f"-> {args.output_dir}\n"
    )
    sys.stdout.flush()

    loaded = load_all_scenarios(args.scenarios_dir, required=requested_ids)
    scenarios_map = {}
    for scenario_id, config in loaded.items():
        scenarios_map[scenario_id] = _override_reps(config, args.reps)

    on_progress = _make_progress_printer(args.quiet)
    began = time.monotonic()
    multi = run_all_scenarios(
        scenarios_map,
        data_dir=args.data_dir,
        scenario_ids=requested_ids,
        progress=on_progress,
    )
    elapsed = time.monotonic() - began

    replications_by_id = {
        scenario_id: result.replications
        for scenario_id, result in multi.results.items()
    }
    run_summary = aggregate_run(replications_by_id)

    results_path, event_log_path, summary_path = _write_flat_outputs(
        multi,
        run_summary,
        args.output_dir,
        event_log_scope=args.event_log_scope,
        expected_scenario_ids=REQUIRED_SCENARIO_IDS if using_default_set else None,
        expected_replications=None if args.reps is not None else 30,
    )

    _print_run_summary(run_summary, results_path, event_log_path, summary_path, elapsed)
    return 0


def cmd_list(args: argparse.Namespace) -> int:
    """Handle ``list``: print one line per scenario YAML in ``--scenarios-dir``.

    Files are listed in sorted order. A file that fails to load is reported
    with the exception class name and does not abort the listing. Required
    scenarios are flagged with ``*``.

    Returns:
        2 (after a message on stderr) when the directory is missing or holds
        no ``*.yaml`` file, otherwise 0.
    """
    folder = args.scenarios_dir
    if not folder.is_dir():
        sys.stderr.write(f"Scenarios directory not found: {folder}\n")
        return 2

    candidates = sorted(folder.glob("*.yaml"))
    if not candidates:
        sys.stderr.write(f"No scenario YAMLs found in {folder}\n")
        return 2

    sys.stdout.write(f"Scenarios in {folder}:\n")
    for candidate in candidates:
        try:
            scenario = load_scenario(candidate)
        except Exception as exc:  # pragma: no cover - defensive UX
            sys.stdout.write(f"  - {candidate.stem}: <error: {exc.__class__.__name__}>\n")
        else:
            if scenario.scenario_id in REQUIRED_SCENARIO_IDS:
                marker = " *"
            else:
                marker = "  "
            line = (
                f"  {marker}{scenario.scenario_id:<24} "
                f"reps={scenario.simulation.replications:<3} "
                f"trucks={scenario.fleet.truck_count:<3} {scenario.description}\n"
            )
            sys.stdout.write(line)
    sys.stdout.write("\n  * = required scenario (default for run-all)\n")
    return 0


def cmd_render(args: argparse.Namespace) -> int:
    """Handle ``render``: write the topology image and/or the animation.

    The topology (unless skipped) is rendered before the event log is
    checked, so it may already be on disk when the animation step fails.

    Returns:
        2 (after a message on stderr) when an animation is requested but the
        event log file does not exist, otherwise 0.
    """
    # Imported lazily so 'run'/'run-all' do not pay the matplotlib import cost.
    from mine_sim.viz import render_animation, render_topology

    if not args.skip_topology:
        topology_path = render_topology(args.data_dir, args.topology_out)
        sys.stdout.write(f"Wrote {topology_path}\n")

    if args.skip_animation:
        return 0

    if not Path(args.event_log).exists():
        sys.stderr.write(
            f"Event log not found: {args.event_log}. Run 'run' or 'run-all' first, "
            "or pass --skip-animation.\n"
        )
        return 2

    animation_path = render_animation(
        args.event_log,
        args.data_dir,
        args.animation_out,
        scenario_id=args.scenario,
        replication=args.replication,
        shift_minutes=args.shift_minutes,
        n_frames=args.frames,
        fps=args.fps,
    )
    sys.stdout.write(f"Wrote {animation_path}\n")
    return 0


def _print_run_summary(summary, results_path, event_log_path, summary_path, elapsed) -> None:
    """Print the ``run-all`` closing report to stdout.

    The report is a completion line with the elapsed seconds and scenario
    count, a fixed-width table with one row per entry of
    ``summary.scenarios`` (mean tonnes, mean tonnes/hour and its 95% CI
    bounds), and the three artefact paths. Stdout is flushed at the end.
    """
    scenario_count = len(summary.scenarios)
    sys.stdout.write(f"\nrun-all complete in {elapsed:.1f}s ({scenario_count} scenarios).\n")

    # Column widths here must match the row format below.
    columns = (
        f"{'scenario_id':<24}",
        f"{'tonnes_mean':>12}",
        f"{'tph_mean':>9}",
        f"{'tph_ci95':>20}",
    )
    header = " ".join(columns)
    rule = "-" * len(header)
    sys.stdout.write(f"{header}\n{rule}\n")

    for scenario_id, stats in summary.scenarios.items():
        tonnes = stats.total_tonnes_delivered
        tph = stats.tonnes_per_hour
        interval = f"[{tph.ci95_low:>7.1f}, {tph.ci95_high:>7.1f}]"
        sys.stdout.write(
            f"{scenario_id:<24} {tonnes.mean:>12.0f} {tph.mean:>9.1f} {interval}\n"
        )

    sys.stdout.write(f"\nArtefacts:\n  {results_path}\n  {event_log_path}\n  {summary_path}\n")
    sys.stdout.flush()


def main(argv: Sequence[str] | None = None) -> int:
    """Parse ``argv`` and dispatch to the handler of the chosen subcommand.

    Args:
        argv: Argument list to parse; passed unchanged to
            ``ArgumentParser.parse_args``.

    Returns:
        The exit code returned by the selected ``cmd_*`` handler.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    handlers = {
        "run": cmd_run,
        "run-all": cmd_run_all,
        "list": cmd_list,
        "render": cmd_render,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        return handler(args)

    # Defensive fallback for a command name with no handler.
    parser.error(f"Unknown command: {args.command}")
    return 2  # pragma: no cover


# Names exported by ``from mine_sim.cli import *``: three default-path
# constants, the parser factory, the four subcommand handlers and ``main``.
__all__ = [
    "DEFAULT_DATA_DIR",
    "DEFAULT_OUTPUT_DIR",
    "DEFAULT_SCENARIOS_DIR",
    "build_parser",
    "cmd_list",
    "cmd_render",
    "cmd_run",
    "cmd_run_all",
    "main",
]