734 lines
32 KiB
Python
734 lines
32 KiB
Python
import json
|
||
import re
|
||
import threading
|
||
import time
|
||
from dataclasses import asdict, dataclass
|
||
from datetime import UTC, datetime
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from pipeline.config import (
|
||
PipelineConfig,
|
||
read_public_key,
|
||
resolve_api_key,
|
||
resolve_ssh_private_key,
|
||
)
|
||
from pipeline.remote import (
|
||
RemoteCommandError,
|
||
SshTarget,
|
||
fetch_directory,
|
||
is_process_alive,
|
||
read_remote_file,
|
||
rsync_file,
|
||
rsync_project,
|
||
run_detached,
|
||
shell_join,
|
||
ssh,
|
||
ssh_output,
|
||
tail_log,
|
||
wait_for_ssh,
|
||
)
|
||
from pipeline.vast_api import VastApiClient, VastApiError, VastInstance
|
||
|
||
|
||
# ── data classes ─────────────────────────────────────────────────────
|
||
|
||
# Snapshot of a pipeline run written to disk for traceability
@dataclass(slots=True)
class RunManifest:
    """On-disk record of one pipeline run.

    Serialized to JSON by EphemeralVastRunner._write_manifest; the mutable
    fields (instance_id, ssh_host, ssh_port, status) are filled in as the
    run progresses and the file is re-written each time.
    """

    created_at: str               # UTC ISO-8601 timestamp; also keys the manifest filename
    config_paths: list[str]       # training configs, relative to the project root
    instance_id: int | None       # Vast.ai instance ID, once the instance is created
    offer_id: int | None          # Vast.ai offer the instance was rented from
    ssh_host: str | None          # SSH host, once the instance is running
    ssh_port: int | None          # SSH port, once the instance is running
    status: str                   # lifecycle state: "creating", then running/"completed"/"cancelled"/"failed"
    remote_workspace: str | None  # workspace directory path on the remote instance
|
||
|
||
|
||
# CLI flags controlling how the pipeline run behaves
@dataclass
class RunOptions:
    """Options parsed from the CLI that tune a pipeline run."""

    download_data: bool = False       # force dataset re-download even if remote data exists
    send_cropped: bool = False        # rsync a local cropped_*.zip archive to the instance
    keep_on_failure: bool = False     # skip instance destruction when the run fails
    dry_run: bool = False             # print the chosen offer(s) and stop before renting
    select_offer: bool = False        # interactively pick the offer instead of taking the top-ranked one
    select_template: bool = False     # interactively pick the Docker image template
    template_hash: str | None = None  # explicit Vast template hash; overrides selection and config
    sort_mode: str | None = None      # offer ranking: "performance", "price" or "dlp_per_dollar"
    region: str | None = None         # region filter string, e.g. "europe"
    price_cap: float | None = None    # price cap forwarded to the offer search query
    list_regions: bool = False        # print available regions with offer counts, then exit
    use_gpu: bool = True              # pass --use-gpu to the remote training script
|
||
|
||
|
||
# ── constants ─────────────────────────────────────────────────────────
|
||
|
||
# ISO-3166-1-alpha-2 country codes used by the --region europe filter
EUROPE_REGION_CODES = set(
    "AL AD AM AT AZ BA BE BG BY CH CY CZ "
    "DE DK EE ES FI FR GB GE GR HR HU IE "
    "IS IT LI LT LU LV MC MD ME MK MT NL "
    "NO PL PT RO RS SE SI SK SM TR UA VA".split()
)
|
||
|
||
# How often (in epochs) to rsync generator outputs back during training.
# Drives the periodic sync check in EphemeralVastRunner._run_training so
# partial generator outputs reach the local machine before the run ends.
_GENERATOR_SYNC_INTERVAL = 50
|
||
|
||
|
||
# Raised when the user aborts interactive offer selection
class OfferSelectionAborted(RuntimeError):
    """Raised when the user enters 'q' at the interactive offer prompt."""

    pass
|
||
|
||
|
||
# Raised when the user aborts interactive template selection
class TemplateSelectionAborted(RuntimeError):
    """Raised when the user enters 'q' at the interactive template prompt."""

    pass
|
||
|
||
|
||
# ── runner ────────────────────────────────────────────────────────────
|
||
|
||
# Orchestrates the full lifecycle: search offers -> rent GPU -> train -> fetch results -> destroy
|
||
class EphemeralVastRunner:
|
||
|
||
# Pre-defined Vast.ai image templates the user can pick interactively
|
||
TEMPLATE_CATALOG: list[dict[str, str]] = [
|
||
{
|
||
"hash_id": "661d064bbda1f2a133816b6d55da07c3",
|
||
"name": "PyTorch (cuDNN Devel)",
|
||
"image": "pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel",
|
||
},
|
||
{
|
||
"hash_id": "b9e5a8f3d4c1e7f6a2b0d8c9e5f1a3b7",
|
||
"name": "PyTorch (Latest)",
|
||
"image": "pytorch/pytorch:latest",
|
||
},
|
||
{
|
||
"hash_id": "none",
|
||
"name": "Custom image (no template)",
|
||
"image": "",
|
||
},
|
||
]
|
||
|
||
# Resolve the project root (two levels up from this file)
|
||
def __init__(self, override_path: Path | None) -> None:
|
||
self.project_root = Path(__file__).resolve().parent.parent
|
||
self.config = PipelineConfig.load(self.project_root, override_path)
|
||
self.api = VastApiClient(resolve_api_key())
|
||
self.private_key = resolve_ssh_private_key()
|
||
self.public_key = read_public_key(self.private_key)
|
||
self.exclude_file = self.project_root / "pipeline" / "rsync-excludes.txt"
|
||
|
||
# ── formatting helpers ────────────────────────────────────────────
|
||
|
||
# Format seconds into a compact human-readable string (e.g. "1d2h30m")
|
||
@staticmethod
|
||
def _fmt_duration(seconds: float) -> str:
|
||
s = max(int(seconds), 0)
|
||
days, s = divmod(s, 86400)
|
||
hours, s = divmod(s, 3600)
|
||
minutes, _ = divmod(s, 60)
|
||
parts: list[str] = []
|
||
if days:
|
||
parts.append(f"{days}d")
|
||
if hours or days:
|
||
parts.append(f"{hours}h")
|
||
parts.append(f"{minutes}m")
|
||
return "".join(parts)
|
||
|
||
# Estimate cost from wall-clock time and the offer's $/h rate
|
||
@staticmethod
|
||
def _fmt_cost(elapsed_seconds: float, cost_per_hour: float) -> str:
|
||
return f"${(elapsed_seconds / 3600) * cost_per_hour:.2f}"
|
||
|
||
# ── offer search & selection ──────────────────────────────────────
|
||
|
||
# Build a sort tuple that ranks offers by the chosen mode (performance / price / dlp_per_dollar)
|
||
@staticmethod
|
||
def _offer_sort_key(offer: dict, sort_mode: str) -> tuple:
|
||
price = float(offer.get("dph_total", float("inf")))
|
||
dlperf = float(offer.get("dlperf", float("-inf")))
|
||
reliability = float(offer.get("reliability2", 0.0))
|
||
if sort_mode == "price":
|
||
return (price, -dlperf, -reliability)
|
||
if sort_mode == "dlp_per_dollar":
|
||
per_dollar = dlperf / price if price > 0 else 0.0
|
||
return (-per_dollar, -reliability, price)
|
||
return (-dlperf, price, -reliability)
|
||
|
||
# Match a user-friendly region string (e.g. "europe") against the offer's geolocation code
|
||
@staticmethod
|
||
def _region_matches(offered: str, requested: str) -> bool:
|
||
req = requested.strip().lower()
|
||
if not req:
|
||
return True
|
||
if req == "europe":
|
||
code = offered.rsplit(", ", 1)[-1].strip() if ", " in offered else ""
|
||
return code in EUROPE_REGION_CODES
|
||
return req in offered.lower()
|
||
|
||
# Query Vast.ai for matching offers, optionally filter by region, then sort
|
||
def _search_offers(
|
||
self,
|
||
sort_mode: str | None = None,
|
||
region: str | None = None,
|
||
price_cap: float | None = None,
|
||
) -> list[dict]:
|
||
query = self.config.build_offer_query(price_cap=price_cap)
|
||
offers = self.api.search_offers(query)
|
||
if not offers:
|
||
raise VastApiError("No Vast offers matched the configured filters.")
|
||
if region:
|
||
offers = [o for o in offers if self._region_matches(o.get("geolocation") or "", region)]
|
||
if not offers:
|
||
raise VastApiError(f"No offers matched region filter: {region!r}")
|
||
mode = sort_mode or self.config.search.get("sort_mode", "performance")
|
||
return sorted(offers, key=lambda o: self._offer_sort_key(o, mode))
|
||
|
||
# Print a summary of available regions and their offer counts
|
||
@staticmethod
|
||
def _print_regions(offers: list[dict]) -> None:
|
||
counts: dict[str, int] = {}
|
||
for o in offers:
|
||
region = o.get("geolocation") or "Unknown"
|
||
counts[region] = counts.get(region, 0) + 1
|
||
for region, count in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
|
||
print(f"{region}: {count}")
|
||
|
||
# Render one page of offers with GPU specs, price, and reliability
|
||
def _print_offer_page(self, offers: list[dict], *, page: int, page_size: int) -> None:
|
||
start = page * page_size
|
||
end = min(start + page_size, len(offers))
|
||
print(f"Offers {start + 1}–{end} of {len(offers)}")
|
||
for idx in range(start, end):
|
||
o = offers[idx]
|
||
vram_gb = o.get("gpu_ram", 0) / 1024
|
||
duration = self._fmt_duration(float(o.get("duration") or 0))
|
||
print(
|
||
f" [{idx + 1}] {o.get('gpu_name')} "
|
||
f"gpus={o.get('num_gpus')} "
|
||
f"vram={vram_gb:.0f}GB "
|
||
f"dlperf={o.get('dlperf', 0):.1f} "
|
||
f"$/h={o.get('dph_total', 0):.3f} "
|
||
f"reliability={o.get('reliability2', 0):.3f} "
|
||
f"avail={duration} "
|
||
f"region={o.get('geolocation')} "
|
||
f"id={o.get('id')}"
|
||
)
|
||
|
||
# Paginated interactive prompt — user picks an offer or aborts
|
||
def _choose_offer(self, offers: list[dict]) -> dict:
|
||
page, page_size = 0, 10
|
||
last_page = max((len(offers) - 1) // page_size, 0)
|
||
while True:
|
||
self._print_offer_page(offers, page=page, page_size=page_size)
|
||
raw = input("Select offer [number / n / p / q]: ").strip().lower()
|
||
if raw == "n":
|
||
page = min(page + 1, last_page)
|
||
elif raw == "p":
|
||
page = max(page - 1, 0)
|
||
elif raw == "q":
|
||
raise OfferSelectionAborted
|
||
elif raw.isdigit():
|
||
idx = int(raw) - 1
|
||
if 0 <= idx < len(offers):
|
||
return offers[idx]
|
||
print(" Out of range.")
|
||
else:
|
||
print(" Invalid input.")
|
||
|
||
# ── template selection ────────────────────────────────────────────
|
||
|
||
# Decide which Vast template hash to use: CLI flag, interactive pick, or config default
|
||
def _resolve_template(self, opts: RunOptions) -> str | None:
|
||
if opts.template_hash:
|
||
return opts.template_hash
|
||
if opts.select_template:
|
||
try:
|
||
return self._choose_template()
|
||
except (TemplateSelectionAborted, KeyboardInterrupt):
|
||
print("Template selection aborted, using default image.")
|
||
return None
|
||
return self.config.instance.get("template_hash_id")
|
||
|
||
# Interactive prompt — user picks a Docker image template
|
||
def _choose_template(self) -> str | None:
|
||
print("Available templates:")
|
||
for idx, tpl in enumerate(self.TEMPLATE_CATALOG, 1):
|
||
print(f" [{idx}] {tpl['name']} ({tpl['image'] or 'config image'})")
|
||
while True:
|
||
raw = input("Select template [number / q]: ").strip().lower()
|
||
if raw == "q":
|
||
raise TemplateSelectionAborted
|
||
if raw.isdigit():
|
||
idx = int(raw) - 1
|
||
if 0 <= idx < len(self.TEMPLATE_CATALOG):
|
||
tpl = self.TEMPLATE_CATALOG[idx]
|
||
if tpl["hash_id"] == "none":
|
||
return None
|
||
if tpl["image"]:
|
||
self.config.instance["image"] = tpl["image"]
|
||
print(f" Selected: {tpl['name']}")
|
||
return tpl["hash_id"]
|
||
print(" Invalid input.")
|
||
|
||
# ── instance lifecycle ────────────────────────────────────────────
|
||
|
||
# Build a unique instance label from a prefix, stem, and UTC timestamp
|
||
def _label_for(self, stem: str) -> str:
|
||
ts = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
|
||
return f"{self.config.instance['label_prefix']}-{stem}-{ts}"
|
||
|
||
# Assemble the JSON payload sent to Vast.ai when renting an instance
|
||
def _build_instance_payload(self, label: str, *, template_hash_id: str | None = None) -> dict[str, Any]:
|
||
inst = self.config.instance
|
||
payload: dict[str, Any] = {
|
||
"image": inst["image"],
|
||
"disk": inst["disk_gb"],
|
||
"runtype": "ssh_direc ssh_proxy",
|
||
"target_state": inst["target_state"],
|
||
"label": label,
|
||
}
|
||
if template_hash_id:
|
||
payload["template_hash_id"] = template_hash_id
|
||
return payload
|
||
|
||
# Poll the API until the instance is running and has SSH credentials
|
||
def _wait_for_instance(self, instance_id: int) -> VastInstance:
|
||
deadline = time.time() + self.config.remote["ssh_timeout_seconds"]
|
||
poll = self.config.remote["poll_interval_seconds"]
|
||
while time.time() < deadline:
|
||
instance = self.api.show_instance(instance_id)
|
||
if instance.actual_status == "running" and instance.ssh_host and instance.ssh_port:
|
||
return instance
|
||
print(f" instance status: {instance.actual_status or 'pending'}...")
|
||
time.sleep(poll)
|
||
raise VastApiError(f"Timed out waiting for instance {instance_id} to be SSH-ready.")
|
||
|
||
# Build an SshTarget from the running instance's connection details
|
||
def _build_target(self, instance: VastInstance) -> SshTarget:
|
||
return SshTarget(
|
||
user=self.config.remote["ssh_user"],
|
||
host=instance.ssh_host or instance.public_ipaddr or "",
|
||
port=int(instance.ssh_port or 22),
|
||
private_key=self.private_key,
|
||
)
|
||
|
||
# ── manifest ──────────────────────────────────────────────────────
|
||
|
||
# Persist the run manifest as a timestamped JSON file under the output directory
|
||
def _write_manifest(self, manifest: RunManifest, local_output_root: Path) -> Path:
|
||
d = local_output_root / "pipeline"
|
||
d.mkdir(parents=True, exist_ok=True)
|
||
ts = manifest.created_at.replace(":", "").replace("-", "")
|
||
path = d / f"{ts}.json"
|
||
with open(path, "w", encoding="utf-8") as fh:
|
||
json.dump(asdict(manifest), fh, indent=2)
|
||
return path
|
||
|
||
# ── remote workspace setup ────────────────────────────────────────
|
||
|
||
# Upload project files, bootstrap the venv, and ensure training data is present
|
||
def _prepare_remote_workspace(
|
||
self, target: SshTarget, *, download_data: bool, send_cropped: bool,
|
||
first_config: Path | None = None,
|
||
) -> None:
|
||
ws = self.config.remote["workspace_dir"]
|
||
poll = self.config.remote["poll_interval_seconds"]
|
||
|
||
print(" Syncing project files...")
|
||
ssh(target, shell_join(["mkdir", "-p", ws]))
|
||
rsync_project(self.project_root, target, ws, self.exclude_file)
|
||
|
||
print(" Bootstrapping remote environment...")
|
||
ssh(target, shell_join(["bash", "-lc", f"cd {ws} && bash pipeline/scripts/bootstrap_env.sh"]))
|
||
|
||
if send_cropped and first_config is not None:
|
||
# Determine which cropped subdirectory to send based on the first config
|
||
first_mod, _ = self._module_for_config(first_config)
|
||
crop_subdir = "generator" if first_mod.startswith("generator/") else "classifier"
|
||
local_zip = self.project_root / f"cropped_{crop_subdir}.zip"
|
||
if not local_zip.exists():
|
||
print(f" Warning: --send-cropped set but cropped_{crop_subdir}.zip not found.")
|
||
print(f" Create it with: zip -r cropped_{crop_subdir}.zip cropped/{crop_subdir}/")
|
||
else:
|
||
size_mb = local_zip.stat().st_size / (1024 * 1024)
|
||
print(f" Sending cropped_{crop_subdir}.zip ({size_mb:.0f} MB)...")
|
||
rsync_file(local_zip, target, ws)
|
||
print(" Unzipping on remote...")
|
||
ssh(target, shell_join(["bash", "-lc", f"cd {ws} && unzip -q -o cropped_{crop_subdir}.zip"]))
|
||
print(f" Pre-cropped {crop_subdir} images ready.")
|
||
|
||
data_ready = self._remote_data_exists(target, f"{ws}/data")
|
||
if download_data or not data_ready:
|
||
# Download the dataset in a detached process so the SSH timeout doesn't kill it
|
||
print(" Downloading dataset from HuggingFace...")
|
||
log = f"{ws}/.pipeline_logs/fetch_ds.log"
|
||
ssh(target, f"mkdir -p {ws}/.pipeline_logs")
|
||
pid = run_detached(
|
||
target,
|
||
f"cd {ws} && source .venv/bin/activate && python3 classifier/tools/fetch_ds.py",
|
||
log,
|
||
)
|
||
reconnects = 0
|
||
# Poll until the fetch process exits, with reconnect logic for SSH drops
|
||
while True:
|
||
try:
|
||
if not is_process_alive(target, pid):
|
||
break
|
||
reconnects = 0
|
||
except RemoteCommandError as exc:
|
||
if "Command failed (255)" not in str(exc) or reconnects >= 10:
|
||
raise
|
||
reconnects += 1
|
||
time.sleep(min(5 * reconnects, 30))
|
||
continue
|
||
time.sleep(poll)
|
||
print(" Dataset ready.")
|
||
|
||
# ── config routing ────────────────────────────────────────────────
|
||
|
||
# Determine which training script and output dir to use based on config location
|
||
def _module_for_config(self, config_path: Path) -> tuple[str, str]:
|
||
try:
|
||
rel = config_path.resolve().relative_to(self.project_root)
|
||
except ValueError:
|
||
rel = config_path
|
||
if rel.parts[0] == "generator":
|
||
return "generator/run.py", "generator/outputs"
|
||
return "classifier/run.py", "classifier/outputs"
|
||
|
||
# ── remote directory checks ───────────────────────────────────────
|
||
|
||
# Check whether a directory exists on the remote host
|
||
def _remote_dir_exists(self, target: SshTarget, path: str) -> bool:
|
||
out = ssh_output(target, f"if [ -d {path} ]; then echo yes; else echo no; fi").strip()
|
||
return out == "yes"
|
||
|
||
# Check that the dataset is populated (not just an empty directory)
|
||
def _remote_data_exists(self, target: SshTarget, data_dir: str) -> bool:
|
||
out = ssh_output(target, f"if [ -d {data_dir}/wiki ]; then echo yes; else echo no; fi").strip()
|
||
return out == "yes"
|
||
|
||
# ── training ─────────────────────────────────────────────────────
|
||
|
||
# Launch the training script inside nohup, stream logs, and handle reconnects
|
||
def _run_training(self, target: SshTarget, config_path: Path, opts: RunOptions) -> None:
|
||
ws = self.config.remote["workspace_dir"]
|
||
config_rel = config_path.resolve().relative_to(self.project_root)
|
||
run_script, output_root = self._module_for_config(config_path)
|
||
is_generator = run_script.startswith("generator/")
|
||
|
||
run_cmd = shell_join(["python3", "-u", run_script, str(config_rel), "--output-root", output_root])
|
||
if opts.use_gpu:
|
||
run_cmd += " --use-gpu"
|
||
|
||
log_dir = f"{ws}/.pipeline_logs"
|
||
log_path = f"{log_dir}/{config_path.stem}.log"
|
||
exit_path = f"{log_dir}/{config_path.stem}.exit"
|
||
|
||
ssh(target, f"mkdir -p {log_dir}")
|
||
|
||
# Run training under nohup; write the exit code to a marker file when done
|
||
pid = run_detached(
|
||
target,
|
||
f"cd {ws} && source .venv/bin/activate && {run_cmd}; echo $? > {exit_path}",
|
||
log_path,
|
||
)
|
||
print(f" Training started (PID {pid})")
|
||
|
||
poll = self.config.remote["poll_interval_seconds"]
|
||
max_reconnects = 10
|
||
reconnects = 0
|
||
next_sync = _GENERATOR_SYNC_INTERVAL if is_generator else None
|
||
local_output_root = self.project_root / output_root
|
||
stop_streaming = threading.Event()
|
||
|
||
# Background thread: tail the remote log and print tqdm-aware output
|
||
def _stream_worker() -> None:
|
||
tail_proc = None
|
||
last_was_progress = False
|
||
try:
|
||
time.sleep(1) # wait for the log file to be created
|
||
tail_proc = tail_log(target, log_path)
|
||
for raw in tail_proc.stdout:
|
||
if stop_streaming.is_set():
|
||
break
|
||
# tqdm uses \r for in-place updates; take only the last segment
|
||
line = raw.rstrip("\r\n").rsplit("\r", 1)[-1].rstrip()
|
||
if not line:
|
||
continue
|
||
is_progress = "it/s]" in line or "%|" in line
|
||
if is_progress:
|
||
print(f"\r {line}", end="", flush=True)
|
||
last_was_progress = True
|
||
else:
|
||
if last_was_progress:
|
||
print(flush=True)
|
||
last_was_progress = False
|
||
print(f" {line}", flush=True)
|
||
except Exception:
|
||
pass
|
||
finally:
|
||
if last_was_progress:
|
||
print(flush=True)
|
||
if tail_proc is not None:
|
||
tail_proc.terminate()
|
||
tail_proc.wait(timeout=5)
|
||
|
||
stream_thread = threading.Thread(target=_stream_worker, daemon=True)
|
||
stream_thread.start()
|
||
|
||
try:
|
||
# Main polling loop: check if the process is alive, periodically sync generator outputs
|
||
while True:
|
||
try:
|
||
alive = is_process_alive(target, pid)
|
||
except RemoteCommandError as exc:
|
||
if "Command failed (255)" not in str(exc) or reconnects >= max_reconnects:
|
||
raise
|
||
reconnects += 1
|
||
print(f"\n SSH dropped (reconnect {reconnects}/{max_reconnects}), "
|
||
"training continues on remote...")
|
||
time.sleep(min(5 * reconnects, 30))
|
||
continue
|
||
|
||
if is_generator and next_sync is not None:
|
||
epoch = self._latest_epoch_in_log(target, log_path)
|
||
if epoch >= next_sync:
|
||
print(f" Syncing generator outputs at epoch {next_sync}...")
|
||
self._fetch_outputs(target, output_root, local_output_root)
|
||
next_sync += _GENERATOR_SYNC_INTERVAL
|
||
|
||
if not alive:
|
||
break
|
||
|
||
reconnects = 0
|
||
time.sleep(poll)
|
||
finally:
|
||
stop_streaming.set()
|
||
stream_thread.join(timeout=10)
|
||
|
||
# Read the exit code marker that the nohup wrapper wrote
|
||
raw_exit = read_remote_file(target, exit_path)
|
||
exit_code = int(raw_exit) if raw_exit is not None else -1
|
||
|
||
if exit_code != 0:
|
||
if self._remote_dir_exists(target, f"{ws}/{output_root}"):
|
||
self._fetch_outputs(target, output_root, local_output_root)
|
||
log_tail = read_remote_file(target, log_path)
|
||
snippet = "\n".join((log_tail or "").splitlines()[-30:])
|
||
raise RemoteCommandError(f"Training failed (exit {exit_code}).\n{snippet}")
|
||
|
||
# Parse the latest epoch number from the training log (used for generator sync scheduling)
|
||
@staticmethod
|
||
def _latest_epoch_in_log(target: SshTarget, log_path: str) -> int:
|
||
try:
|
||
out = ssh_output(target, f"tail -n 200 {log_path} 2>/dev/null || true")
|
||
epochs = [int(m) for m in re.findall(r"\[(\d+)/\d+\]", out)]
|
||
return max(epochs, default=0)
|
||
except RemoteCommandError:
|
||
return 0
|
||
|
||
# Download remote outputs with up to 3 retry attempts (exponential back-off)
|
||
def _fetch_outputs(self, target: SshTarget, remote_output_root: str, local_output_root: Path) -> None:
|
||
remote_path = f"{self.config.remote['workspace_dir']}/{remote_output_root}"
|
||
delay = 5
|
||
for attempt in range(1, 4):
|
||
try:
|
||
fetch_directory(target, remote_path, local_output_root)
|
||
return
|
||
except RemoteCommandError as exc:
|
||
if attempt < 3:
|
||
print(f" Download attempt {attempt} failed: {exc}")
|
||
print(f" Retrying in {delay}s...")
|
||
time.sleep(delay)
|
||
delay *= 2
|
||
else:
|
||
raise
|
||
|
||
# ── public commands ───────────────────────────────────────────────
|
||
|
||
# Full pipeline: validate configs -> find offer -> rent instance -> train -> fetch -> destroy
|
||
def run(self, config_paths: list[Path], opts: RunOptions) -> None:
|
||
resolved = []
|
||
for cp in config_paths:
|
||
cp = (self.project_root / cp).resolve() if not cp.is_absolute() else cp
|
||
if not cp.exists():
|
||
raise FileNotFoundError(f"Config not found: {cp}")
|
||
resolved.append(cp)
|
||
|
||
n = len(resolved)
|
||
_, first_output_root = self._module_for_config(resolved[0])
|
||
local_output_root = self.project_root / first_output_root
|
||
|
||
self.api.ensure_ssh_key(self.public_key)
|
||
offers = self._search_offers(opts.sort_mode, opts.region, opts.price_cap)
|
||
|
||
if opts.list_regions:
|
||
self._print_regions(offers)
|
||
return
|
||
|
||
try:
|
||
offer = self._choose_offer(offers) if opts.select_offer else offers[0]
|
||
except (OfferSelectionAborted, KeyboardInterrupt):
|
||
print("Aborted.")
|
||
return
|
||
|
||
if opts.dry_run:
|
||
print(json.dumps(offer if opts.select_offer else offers[:10], indent=2))
|
||
return
|
||
|
||
label_stem = resolved[0].stem if n == 1 else resolved[0].parent.name
|
||
template_hash_id = self._resolve_template(opts)
|
||
offer_id = int(offer["id"])
|
||
cost_per_hour = float(offer.get("dph_total", 0))
|
||
|
||
manifest = RunManifest(
|
||
created_at=datetime.now(UTC).isoformat(),
|
||
config_paths=[str(cp.relative_to(self.project_root)) for cp in resolved],
|
||
instance_id=None,
|
||
offer_id=offer_id,
|
||
ssh_host=None,
|
||
ssh_port=None,
|
||
status="creating",
|
||
remote_workspace=self.config.remote["workspace_dir"],
|
||
)
|
||
manifest_path = self._write_manifest(manifest, local_output_root)
|
||
|
||
instance_id: int | None = None
|
||
should_destroy = True
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# Build the instance payload and rent the GPU
|
||
payload = self._build_instance_payload(self._label_for(label_stem),
|
||
template_hash_id=template_hash_id)
|
||
print(f"Creating instance from offer {offer_id} (${cost_per_hour:.3f}/h)...")
|
||
instance_id = self.api.create_instance(offer_id, payload)
|
||
self.api.attach_ssh_key(instance_id, self.public_key)
|
||
|
||
# Record the instance ID in the manifest immediately for debugging
|
||
manifest.instance_id = instance_id
|
||
|
||
# Wait for the instance to be SSH-ready, then update the manifest
|
||
print(f"Waiting for instance {instance_id}...")
|
||
instance = self._wait_for_instance(instance_id)
|
||
manifest.ssh_host = instance.ssh_host
|
||
manifest.ssh_port = instance.ssh_port
|
||
manifest.status = instance.actual_status
|
||
self._write_manifest(manifest, local_output_root)
|
||
print(f"Instance ready: {instance.ssh_host}:{instance.ssh_port} "
|
||
f"({instance.gpu_name}, ${instance.dph_total}/h)")
|
||
|
||
target = self._build_target(instance)
|
||
|
||
# Wait for SSH to become available before running any remote commands
|
||
print("Waiting for SSH...")
|
||
wait_for_ssh(
|
||
target,
|
||
timeout_seconds=self.config.remote["ssh_timeout_seconds"],
|
||
poll_interval_seconds=self.config.remote["poll_interval_seconds"],
|
||
)
|
||
self._prepare_remote_workspace(
|
||
target, download_data=opts.download_data, send_cropped=opts.send_cropped,
|
||
first_config=resolved[0],
|
||
)
|
||
|
||
# Run each training config sequentially on the same instance
|
||
for i, config_path in enumerate(resolved, 1):
|
||
print(f"\n[{i}/{n}] Training: {config_path.name}")
|
||
self._run_training(target, config_path, opts)
|
||
print(f"[{i}/{n}] Fetching outputs...")
|
||
_, output_root = self._module_for_config(config_path)
|
||
self._fetch_outputs(target, output_root, self.project_root / output_root)
|
||
|
||
elapsed = time.time() - start_time
|
||
manifest.status = "completed"
|
||
self._write_manifest(manifest, local_output_root)
|
||
print(f"\nAll {n} run(s) completed in {self._fmt_duration(elapsed)} "
|
||
f"(~{self._fmt_cost(elapsed, cost_per_hour)}). Manifest: {manifest_path}")
|
||
|
||
except KeyboardInterrupt:
|
||
elapsed = time.time() - start_time
|
||
manifest.status = "cancelled"
|
||
self._write_manifest(manifest, local_output_root)
|
||
print(f"\nCancelled after {self._fmt_duration(elapsed)} "
|
||
f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
|
||
|
||
except Exception:
|
||
elapsed = time.time() - start_time
|
||
manifest.status = "failed"
|
||
self._write_manifest(manifest, local_output_root)
|
||
print(f" Failed after {self._fmt_duration(elapsed)} "
|
||
f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
|
||
if opts.keep_on_failure or self.config.keep_on_failure:
|
||
should_destroy = False
|
||
raise
|
||
|
||
finally:
|
||
if instance_id is not None and should_destroy:
|
||
elapsed = time.time() - start_time
|
||
print(f"Destroying instance {instance_id} "
|
||
f"(total: {self._fmt_duration(elapsed)}, "
|
||
f"~{self._fmt_cost(elapsed, cost_per_hour)})")
|
||
self.api.destroy_instance(instance_id)
|
||
|
||
# CLI subcommand: search and display available GPU offers without renting
|
||
def offers(
|
||
self,
|
||
*,
|
||
sort_mode: str | None,
|
||
region: str | None,
|
||
price_cap: float | None,
|
||
select_offer: bool,
|
||
list_regions: bool,
|
||
limit_output: int,
|
||
) -> None:
|
||
offers = self._search_offers(sort_mode, region, price_cap)
|
||
if list_regions:
|
||
self._print_regions(offers)
|
||
return
|
||
if select_offer:
|
||
try:
|
||
offer = self._choose_offer(offers)
|
||
except (OfferSelectionAborted, KeyboardInterrupt):
|
||
print("Aborted.")
|
||
return
|
||
print(json.dumps(offer, indent=2))
|
||
else:
|
||
print(json.dumps(offers[:limit_output], indent=2))
|
||
|
||
# CLI subcommand: rent an instance and print its connection details (no training)
|
||
def up(self, *, label: str | None, opts: RunOptions | None = None) -> None:
|
||
if opts is None:
|
||
opts = RunOptions()
|
||
self.api.ensure_ssh_key(self.public_key)
|
||
template_hash_id = self._resolve_template(opts)
|
||
offer = self._search_offers()[0]
|
||
offer_id = int(offer["id"])
|
||
payload = self._build_instance_payload(
|
||
label or self._label_for("manual"), template_hash_id=template_hash_id
|
||
)
|
||
instance_id = self.api.create_instance(offer_id, payload)
|
||
self.api.attach_ssh_key(instance_id, self.public_key)
|
||
instance = self._wait_for_instance(instance_id)
|
||
print(json.dumps({
|
||
"offer_id": offer_id,
|
||
"instance_id": instance_id,
|
||
"ssh_host": instance.ssh_host,
|
||
"ssh_port": instance.ssh_port,
|
||
"gpu_name": instance.gpu_name,
|
||
"dph_total": instance.dph_total,
|
||
}, indent=2))
|
||
|
||
# CLI subcommand: print the raw instance status JSON
|
||
def status(self, instance_id: int) -> None:
|
||
print(json.dumps(self.api.show_instance(instance_id).raw, indent=2))
|
||
|
||
# CLI subcommand: destroy a running instance by ID
|
||
def down(self, instance_id: int) -> None:
|
||
self.api.destroy_instance(instance_id)
|
||
print(f"Destroyed instance {instance_id}")
|