Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+817
View File
@@ -0,0 +1,817 @@
import json
import re
import threading
import time
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from pipeline.config import (
PipelineConfig,
read_public_key,
resolve_api_key,
resolve_ssh_private_key,
)
from pipeline.remote import (
RemoteCommandError,
SshTarget,
fetch_directory,
is_process_alive,
read_remote_file,
rsync_file,
rsync_project,
run_detached,
shell_join,
ssh,
ssh_output,
tail_log,
wait_for_ssh,
)
from pipeline.vast_api import VastApiClient, VastApiError, VastInstance
# ── data classes ─────────────────────────────────────────────────────
# Snapshot of a pipeline run written to disk for traceability
@dataclass(slots=True)
class RunManifest:
created_at: str
config_paths: list[str]
instance_id: int | None
offer_id: int | None
ssh_host: str | None
ssh_port: int | None
status: str
remote_workspace: str | None
# CLI flags controlling how the pipeline run behaves
@dataclass
class RunOptions:
download_data: bool = False
send_cropped: bool = False
keep_on_failure: bool = False
dry_run: bool = False
select_offer: bool = False
select_template: bool = False
template_hash: str | None = None
sort_mode: str | None = None
region: str | None = None
price_cap: float | None = None
list_regions: bool = False
use_gpu: bool = True
# ── constants ─────────────────────────────────────────────────────────
# ISO-3166-1-alpha-2 country codes used by the --region europe filter
EUROPE_REGION_CODES = {
"AL", "AD", "AM", "AT", "AZ", "BA", "BE", "BG", "BY", "CH", "CY", "CZ",
"DE", "DK", "EE", "ES", "FI", "FR", "GB", "GE", "GR", "HR", "HU", "IE",
"IS", "IT", "LI", "LT", "LU", "LV", "MC", "MD", "ME", "MK", "MT", "NL",
"NO", "PL", "PT", "RO", "RS", "SE", "SI", "SK", "SM", "TR", "UA", "VA",
}
# How often (in epochs) to rsync generator outputs back during training
_GENERATOR_SYNC_INTERVAL = 50
# Raised when the user aborts interactive offer selection
class OfferSelectionAborted(RuntimeError):
pass
# Raised when the user aborts interactive template selection
class TemplateSelectionAborted(RuntimeError):
pass
# ── runner ────────────────────────────────────────────────────────────
# Orchestrates the full lifecycle: search offers -> rent GPU -> train -> fetch results -> destroy
class EphemeralVastRunner:
# Pre-defined Vast.ai image templates the user can pick interactively
TEMPLATE_CATALOG: list[dict[str, str]] = [
{
"hash_id": "661d064bbda1f2a133816b6d55da07c3",
"name": "PyTorch (cuDNN Devel)",
"image": "pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel",
},
{
"hash_id": "b9e5a8f3d4c1e7f6a2b0d8c9e5f1a3b7",
"name": "PyTorch (Latest)",
"image": "pytorch/pytorch:latest",
},
{
"hash_id": "none",
"name": "Custom image (no template)",
"image": "",
},
]
# Resolve the project root (two levels up from this file)
def __init__(self, override_path: Path | None) -> None:
self.project_root = Path(__file__).resolve().parent.parent
self.config = PipelineConfig.load(self.project_root, override_path)
self.api = VastApiClient(resolve_api_key())
self.private_key = resolve_ssh_private_key()
self.public_key = read_public_key(self.private_key)
self.exclude_file = self.project_root / "pipeline" / "rsync-excludes.txt"
# ── formatting helpers ────────────────────────────────────────────
# Format seconds into a compact human-readable string (e.g. "1d2h30m")
@staticmethod
def _fmt_duration(seconds: float) -> str:
s = max(int(seconds), 0)
days, s = divmod(s, 86400)
hours, s = divmod(s, 3600)
minutes, _ = divmod(s, 60)
parts: list[str] = []
if days:
parts.append(f"{days}d")
if hours or days:
parts.append(f"{hours}h")
parts.append(f"{minutes}m")
return "".join(parts)
# Estimate cost from wall-clock time and the offer's $/h rate
@staticmethod
def _fmt_cost(elapsed_seconds: float, cost_per_hour: float) -> str:
return f"${(elapsed_seconds / 3600) * cost_per_hour:.2f}"
# ── offer search & selection ──────────────────────────────────────
# Build a sort tuple that ranks offers by the chosen mode (performance / price / dlp_per_dollar)
@staticmethod
def _offer_sort_key(offer: dict, sort_mode: str) -> tuple:
price = float(offer.get("dph_total", float("inf")))
dlperf = float(offer.get("dlperf", float("-inf")))
reliability = float(offer.get("reliability2", 0.0))
if sort_mode == "price":
return (price, -dlperf, -reliability)
if sort_mode == "dlp_per_dollar":
per_dollar = dlperf / price if price > 0 else 0.0
return (-per_dollar, -reliability, price)
return (-dlperf, price, -reliability)
# Match a user-friendly region string (e.g. "europe") against the offer's geolocation code
@staticmethod
def _region_matches(offered: str, requested: str) -> bool:
req = requested.strip().lower()
if not req:
return True
if req == "europe":
code = offered.rsplit(", ", 1)[-1].strip() if ", " in offered else ""
return code in EUROPE_REGION_CODES
return req in offered.lower()
# Query Vast.ai for matching offers, optionally filter by region, then sort
def _search_offers(
self,
sort_mode: str | None = None,
region: str | None = None,
price_cap: float | None = None,
) -> list[dict]:
query = self.config.build_offer_query(price_cap=price_cap)
offers = self.api.search_offers(query)
if not offers:
raise VastApiError("No Vast offers matched the configured filters.")
if region:
offers = [o for o in offers if self._region_matches(o.get("geolocation") or "", region)]
if not offers:
raise VastApiError(f"No offers matched region filter: {region!r}")
mode = sort_mode or self.config.search.get("sort_mode", "performance")
return sorted(offers, key=lambda o: self._offer_sort_key(o, mode))
# Print a summary of available regions and their offer counts
@staticmethod
def _print_regions(offers: list[dict]) -> None:
counts: dict[str, int] = {}
for o in offers:
region = o.get("geolocation") or "Unknown"
counts[region] = counts.get(region, 0) + 1
for region, count in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
print(f"{region}: {count}")
# Render one page of offers with GPU specs, price, and reliability
def _print_offer_page(self, offers: list[dict], *, page: int, page_size: int) -> None:
start = page * page_size
end = min(start + page_size, len(offers))
print(f"Offers {start + 1}{end} of {len(offers)}")
for idx in range(start, end):
o = offers[idx]
vram_gb = o.get("gpu_ram", 0) / 1024
duration = self._fmt_duration(float(o.get("duration") or 0))
print(
f" [{idx + 1}] {o.get('gpu_name')} "
f"gpus={o.get('num_gpus')} "
f"vram={vram_gb:.0f}GB "
f"dlperf={o.get('dlperf', 0):.1f} "
f"$/h={o.get('dph_total', 0):.3f} "
f"reliability={o.get('reliability2', 0):.3f} "
f"avail={duration} "
f"region={o.get('geolocation')} "
f"id={o.get('id')}"
)
# Paginated interactive prompt — user picks an offer or aborts
def _choose_offer(self, offers: list[dict]) -> dict:
page, page_size = 0, 10
last_page = max((len(offers) - 1) // page_size, 0)
while True:
self._print_offer_page(offers, page=page, page_size=page_size)
raw = input("Select offer [number / n / p / q]: ").strip().lower()
if raw == "n":
page = min(page + 1, last_page)
elif raw == "p":
page = max(page - 1, 0)
elif raw == "q":
raise OfferSelectionAborted
elif raw.isdigit():
idx = int(raw) - 1
if 0 <= idx < len(offers):
return offers[idx]
print(" Out of range.")
else:
print(" Invalid input.")
# ── template selection ────────────────────────────────────────────
# Decide which Vast template hash to use: CLI flag, interactive pick, or config default
def _resolve_template(self, opts: RunOptions) -> str | None:
if opts.template_hash:
return opts.template_hash
if opts.select_template:
try:
return self._choose_template()
except (TemplateSelectionAborted, KeyboardInterrupt):
print("Template selection aborted, using default image.")
return None
return self.config.instance.get("template_hash_id")
# Interactive prompt — user picks a Docker image template
def _choose_template(self) -> str | None:
print("Available templates:")
for idx, tpl in enumerate(self.TEMPLATE_CATALOG, 1):
print(f" [{idx}] {tpl['name']} ({tpl['image'] or 'config image'})")
while True:
raw = input("Select template [number / q]: ").strip().lower()
if raw == "q":
raise TemplateSelectionAborted
if raw.isdigit():
idx = int(raw) - 1
if 0 <= idx < len(self.TEMPLATE_CATALOG):
tpl = self.TEMPLATE_CATALOG[idx]
if tpl["hash_id"] == "none":
return None
if tpl["image"]:
self.config.instance["image"] = tpl["image"]
print(f" Selected: {tpl['name']}")
return tpl["hash_id"]
print(" Invalid input.")
# ── instance lifecycle ────────────────────────────────────────────
# Build a unique instance label from a prefix, stem, and UTC timestamp
def _label_for(self, stem: str) -> str:
ts = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
return f"{self.config.instance['label_prefix']}-{stem}-{ts}"
# Assemble the JSON payload sent to Vast.ai when renting an instance
def _build_instance_payload(self, label: str, *, template_hash_id: str | None = None) -> dict[str, Any]:
inst = self.config.instance
payload: dict[str, Any] = {
"image": inst["image"],
"disk": inst["disk_gb"],
"runtype": "ssh_direc ssh_proxy",
"target_state": inst["target_state"],
"label": label,
}
if template_hash_id:
payload["template_hash_id"] = template_hash_id
return payload
# Poll the API until the instance is running and has SSH credentials
def _wait_for_instance(self, instance_id: int) -> VastInstance:
deadline = time.time() + self.config.remote["ssh_timeout_seconds"]
poll = self.config.remote["poll_interval_seconds"]
while time.time() < deadline:
instance = self.api.show_instance(instance_id)
if instance.actual_status == "running" and instance.ssh_host and instance.ssh_port:
return instance
print(f" instance status: {instance.actual_status or 'pending'}...")
time.sleep(poll)
raise VastApiError(f"Timed out waiting for instance {instance_id} to be SSH-ready.")
# Build an SshTarget from the running instance's connection details
def _build_target(self, instance: VastInstance) -> SshTarget:
return SshTarget(
user=self.config.remote["ssh_user"],
host=instance.ssh_host or instance.public_ipaddr or "",
port=int(instance.ssh_port or 22),
private_key=self.private_key,
)
# ── manifest ──────────────────────────────────────────────────────
# Persist the run manifest as a timestamped JSON file under the output directory
def _write_manifest(self, manifest: RunManifest, local_output_root: Path) -> Path:
d = local_output_root / "pipeline"
d.mkdir(parents=True, exist_ok=True)
ts = manifest.created_at.replace(":", "").replace("-", "")
path = d / f"{ts}.json"
with open(path, "w", encoding="utf-8") as fh:
json.dump(asdict(manifest), fh, indent=2)
return path
# ── remote workspace setup ────────────────────────────────────────
# Upload project files, bootstrap the venv, and ensure training data is present
def _prepare_remote_workspace(
self, target: SshTarget, *, download_data: bool, send_cropped: bool,
first_config: Path | None = None,
) -> None:
ws = self.config.remote["workspace_dir"]
poll = self.config.remote["poll_interval_seconds"]
print(" Syncing project files...")
ssh(target, shell_join(["mkdir", "-p", ws]))
rsync_project(self.project_root, target, ws, self.exclude_file)
print(" Bootstrapping remote environment...")
ssh(target, shell_join(["bash", "-lc", f"cd {ws} && bash pipeline/scripts/bootstrap_env.sh"]))
if send_cropped and first_config is not None:
# Determine which cropped subdirectory to send based on the first config
first_mod, _ = self._module_for_config(first_config)
crop_subdir = "generator" if first_mod.startswith("generator/") else "classifier"
local_zip = self.project_root / f"cropped_{crop_subdir}.zip"
if not local_zip.exists():
print(f" Warning: --send-cropped set but cropped_{crop_subdir}.zip not found.")
print(f" Create it with: zip -r cropped_{crop_subdir}.zip cropped/{crop_subdir}/")
else:
size_mb = local_zip.stat().st_size / (1024 * 1024)
print(f" Sending cropped_{crop_subdir}.zip ({size_mb:.0f} MB)...")
rsync_file(local_zip, target, ws)
print(" Unzipping on remote...")
ssh(target, shell_join(["bash", "-lc", f"cd {ws} && unzip -q -o cropped_{crop_subdir}.zip"]))
print(f" Pre-cropped {crop_subdir} images ready.")
data_ready = self._remote_data_exists(target, f"{ws}/data")
if download_data or not data_ready:
# Download the dataset in a detached process so the SSH timeout doesn't kill it
print(" Downloading dataset from HuggingFace...")
log = f"{ws}/.pipeline_logs/fetch_ds.log"
ssh(target, f"mkdir -p {ws}/.pipeline_logs")
pid = run_detached(
target,
f"cd {ws} && source .venv/bin/activate && python3 classifier/tools/fetch_ds.py",
log,
)
reconnects = 0
# Poll until the fetch process exits, with reconnect logic for SSH drops
while True:
try:
if not is_process_alive(target, pid):
break
reconnects = 0
except RemoteCommandError as exc:
if "Command failed (255)" not in str(exc) or reconnects >= 10:
raise
reconnects += 1
time.sleep(min(5 * reconnects, 30))
continue
time.sleep(poll)
print(" Dataset ready.")
# ── config routing ────────────────────────────────────────────────
# Determine which training script and output dir to use based on config location
def _module_for_config(self, config_path: Path) -> tuple[str, str]:
try:
rel = config_path.resolve().relative_to(self.project_root)
except ValueError:
rel = config_path
if rel.parts[0] == "generator":
return "generator/run.py", "generator/outputs"
return "classifier/run.py", "classifier/outputs"
# Merge two dicts recursively (override wins on leaf keys)
@staticmethod
def _deep_merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = EphemeralVastRunner._deep_merge_dicts(result[key], value)
else:
result[key] = value
return result
# Recursively resolve "extends" references in JSON configs with cycle detection
def _load_config_with_extends(self, config_path: Path, seen: set[Path] | None = None) -> dict[str, Any]:
if seen is None:
seen = set()
resolved = config_path.resolve()
if resolved in seen:
raise ValueError(f"Circular config inheritance detected at: {config_path}")
seen.add(resolved)
with open(config_path, encoding="utf-8") as fh:
cfg = json.load(fh)
base_ref = cfg.pop("extends", None)
if not base_ref:
seen.remove(resolved)
return cfg
base_path = (config_path.parent / base_ref).resolve()
base_cfg = self._load_config_with_extends(base_path, seen=seen)
seen.remove(resolved)
return self._deep_merge_dicts(base_cfg, cfg)
# Return a stable signature used to detect duplicate training configs
def _normalized_config_signature(self, config_path: Path) -> str:
run_script, _ = self._module_for_config(config_path)
# Generator configs do not currently use shared/extends inheritance.
if run_script.startswith("generator/"):
with open(config_path, encoding="utf-8") as fh:
cfg = json.load(fh)
else:
cfg = self._load_config_with_extends(config_path)
shared_path = config_path.parent.parent / "shared.json"
if shared_path.exists():
with open(shared_path, encoding="utf-8") as fh:
shared_cfg = json.load(fh)
cfg = self._deep_merge_dicts(shared_cfg, cfg)
# run_name should not influence whether two configs are equivalent to train.
cfg.pop("run_name", None)
return json.dumps(cfg, sort_keys=True, separators=(",", ":"))
# Detect configs that are pure extends (only run_name + extends, no new training settings)
# These are pointers to an already-trained experiment and should be skipped
def _is_pure_extend(self, config_path: Path) -> str | None:
with open(config_path, encoding="utf-8") as fh:
raw = json.load(fh)
if "extends" not in raw:
return None
non_meta = {k for k in raw if k not in ("run_name", "extends")}
if non_meta:
return None
base_path = (config_path.parent / raw["extends"]).resolve()
return str(base_path.relative_to(self.project_root))
def _dedupe_training_configs(self, config_paths: list[Path]) -> list[Path]:
seen: dict[str, Path] = {}
deduped: list[Path] = []
for cp in config_paths:
pure_extends = self._is_pure_extend(cp)
if pure_extends is not None:
print(f"Skipping {cp.name} (pure extend of {pure_extends})")
continue
sig = self._normalized_config_signature(cp)
if sig in seen:
first = seen[sig]
print(f"Skipping duplicate config {cp.name} (same training settings as {first.name})")
continue
seen[sig] = cp
deduped.append(cp)
return deduped
# ── remote directory checks ───────────────────────────────────────
# Check whether a directory exists on the remote host
def _remote_dir_exists(self, target: SshTarget, path: str) -> bool:
out = ssh_output(target, f"if [ -d {path} ]; then echo yes; else echo no; fi").strip()
return out == "yes"
# Check that the dataset is populated (not just an empty directory)
def _remote_data_exists(self, target: SshTarget, data_dir: str) -> bool:
out = ssh_output(target, f"if [ -d {data_dir}/wiki ]; then echo yes; else echo no; fi").strip()
return out == "yes"
# ── training ─────────────────────────────────────────────────────
# Launch the training script inside nohup, stream logs, and handle reconnects
def _run_training(self, target: SshTarget, config_path: Path, opts: RunOptions) -> None:
ws = self.config.remote["workspace_dir"]
config_rel = config_path.resolve().relative_to(self.project_root)
run_script, output_root = self._module_for_config(config_path)
is_generator = run_script.startswith("generator/")
run_cmd = shell_join(["python3", "-u", run_script, str(config_rel), "--output-root", output_root])
if opts.use_gpu:
run_cmd += " --use-gpu"
log_dir = f"{ws}/.pipeline_logs"
log_path = f"{log_dir}/{config_path.stem}.log"
exit_path = f"{log_dir}/{config_path.stem}.exit"
ssh(target, f"mkdir -p {log_dir}")
# Run training under nohup; write the exit code to a marker file when done
pid = run_detached(
target,
f"cd {ws} && source .venv/bin/activate && {run_cmd}; echo $? > {exit_path}",
log_path,
)
print(f" Training started (PID {pid})")
poll = self.config.remote["poll_interval_seconds"]
max_reconnects = 10
reconnects = 0
next_sync = _GENERATOR_SYNC_INTERVAL if is_generator else None
local_output_root = self.project_root / output_root
stop_streaming = threading.Event()
# Background thread: tail the remote log and print tqdm-aware output
def _stream_worker() -> None:
tail_proc = None
last_was_progress = False
try:
time.sleep(1) # wait for the log file to be created
tail_proc = tail_log(target, log_path)
for raw in tail_proc.stdout:
if stop_streaming.is_set():
break
# tqdm uses \r for in-place updates; take only the last segment
line = raw.rstrip("\r\n").rsplit("\r", 1)[-1].rstrip()
if not line:
continue
is_progress = "it/s]" in line or "%|" in line
if is_progress:
print(f"\r {line}", end="", flush=True)
last_was_progress = True
else:
if last_was_progress:
print(flush=True)
last_was_progress = False
print(f" {line}", flush=True)
except Exception:
pass
finally:
if last_was_progress:
print(flush=True)
if tail_proc is not None:
tail_proc.terminate()
tail_proc.wait(timeout=5)
stream_thread = threading.Thread(target=_stream_worker, daemon=True)
stream_thread.start()
try:
# Main polling loop: check if the process is alive, periodically sync generator outputs
while True:
try:
alive = is_process_alive(target, pid)
except RemoteCommandError as exc:
if "Command failed (255)" not in str(exc) or reconnects >= max_reconnects:
raise
reconnects += 1
print(f"\n SSH dropped (reconnect {reconnects}/{max_reconnects}), "
"training continues on remote...")
time.sleep(min(5 * reconnects, 30))
continue
if is_generator and next_sync is not None:
epoch = self._latest_epoch_in_log(target, log_path)
if epoch >= next_sync:
print(f" Syncing generator outputs at epoch {next_sync}...")
self._fetch_outputs(target, output_root, local_output_root)
next_sync += _GENERATOR_SYNC_INTERVAL
if not alive:
break
reconnects = 0
time.sleep(poll)
finally:
stop_streaming.set()
stream_thread.join(timeout=10)
# Read the exit code marker that the nohup wrapper wrote
raw_exit = read_remote_file(target, exit_path)
exit_code = int(raw_exit) if raw_exit is not None else -1
if exit_code != 0:
if self._remote_dir_exists(target, f"{ws}/{output_root}"):
self._fetch_outputs(target, output_root, local_output_root)
log_tail = read_remote_file(target, log_path)
snippet = "\n".join((log_tail or "").splitlines()[-30:])
raise RemoteCommandError(f"Training failed (exit {exit_code}).\n{snippet}")
# Parse the latest epoch number from the training log (used for generator sync scheduling)
@staticmethod
def _latest_epoch_in_log(target: SshTarget, log_path: str) -> int:
try:
out = ssh_output(target, f"tail -n 200 {log_path} 2>/dev/null || true")
epochs = [int(m) for m in re.findall(r"\[(\d+)/\d+\]", out)]
return max(epochs, default=0)
except RemoteCommandError:
return 0
# Download remote outputs with up to 3 retry attempts (exponential back-off)
def _fetch_outputs(self, target: SshTarget, remote_output_root: str, local_output_root: Path) -> None:
remote_path = f"{self.config.remote['workspace_dir']}/{remote_output_root}"
delay = 5
for attempt in range(1, 4):
try:
fetch_directory(target, remote_path, local_output_root)
return
except RemoteCommandError as exc:
if attempt < 3:
print(f" Download attempt {attempt} failed: {exc}")
print(f" Retrying in {delay}s...")
time.sleep(delay)
delay *= 2
else:
raise
# ── public commands ───────────────────────────────────────────────
# Full pipeline: validate configs -> find offer -> rent instance -> train -> fetch -> destroy
def run(self, config_paths: list[Path], opts: RunOptions) -> None:
resolved = []
for cp in config_paths:
cp = (self.project_root / cp).resolve() if not cp.is_absolute() else cp
if not cp.exists():
raise FileNotFoundError(f"Config not found: {cp}")
resolved.append(cp)
resolved = self._dedupe_training_configs(resolved)
# Abort early if all configs were duplicates
if not resolved:
raise ValueError("No unique configs to run after deduplication.")
n = len(resolved)
_, first_output_root = self._module_for_config(resolved[0])
local_output_root = self.project_root / first_output_root
self.api.ensure_ssh_key(self.public_key)
offers = self._search_offers(opts.sort_mode, opts.region, opts.price_cap)
if opts.list_regions:
self._print_regions(offers)
return
try:
offer = self._choose_offer(offers) if opts.select_offer else offers[0]
except (OfferSelectionAborted, KeyboardInterrupt):
print("Aborted.")
return
if opts.dry_run:
print(json.dumps(offer if opts.select_offer else offers[:10], indent=2))
return
label_stem = resolved[0].stem if n == 1 else resolved[0].parent.name
template_hash_id = self._resolve_template(opts)
offer_id = int(offer["id"])
cost_per_hour = float(offer.get("dph_total", 0))
manifest = RunManifest(
created_at=datetime.now(UTC).isoformat(),
config_paths=[str(cp.relative_to(self.project_root)) for cp in resolved],
instance_id=None,
offer_id=offer_id,
ssh_host=None,
ssh_port=None,
status="creating",
remote_workspace=self.config.remote["workspace_dir"],
)
manifest_path = self._write_manifest(manifest, local_output_root)
instance_id: int | None = None
should_destroy = True
start_time = time.time()
try:
# Build the instance payload and rent the GPU
payload = self._build_instance_payload(self._label_for(label_stem),
template_hash_id=template_hash_id)
print(f"Creating instance from offer {offer_id} (${cost_per_hour:.3f}/h)...")
instance_id = self.api.create_instance(offer_id, payload)
self.api.attach_ssh_key(instance_id, self.public_key)
# Record the instance ID in the manifest immediately for debugging
manifest.instance_id = instance_id
# Wait for the instance to be SSH-ready, then update the manifest
print(f"Waiting for instance {instance_id}...")
instance = self._wait_for_instance(instance_id)
manifest.ssh_host = instance.ssh_host
manifest.ssh_port = instance.ssh_port
manifest.status = instance.actual_status
self._write_manifest(manifest, local_output_root)
print(f"Instance ready: {instance.ssh_host}:{instance.ssh_port} "
f"({instance.gpu_name}, ${instance.dph_total}/h)")
target = self._build_target(instance)
# Wait for SSH to become available before running any remote commands
print("Waiting for SSH...")
wait_for_ssh(
target,
timeout_seconds=self.config.remote["ssh_timeout_seconds"],
poll_interval_seconds=self.config.remote["poll_interval_seconds"],
)
self._prepare_remote_workspace(
target, download_data=opts.download_data, send_cropped=opts.send_cropped,
first_config=resolved[0],
)
# Run each training config sequentially on the same instance
for i, config_path in enumerate(resolved, 1):
print(f"\n[{i}/{n}] Training: {config_path.name}")
self._run_training(target, config_path, opts)
print(f"[{i}/{n}] Fetching outputs...")
_, output_root = self._module_for_config(config_path)
self._fetch_outputs(target, output_root, self.project_root / output_root)
elapsed = time.time() - start_time
manifest.status = "completed"
self._write_manifest(manifest, local_output_root)
print(f"\nAll {n} run(s) completed in {self._fmt_duration(elapsed)} "
f"(~{self._fmt_cost(elapsed, cost_per_hour)}). Manifest: {manifest_path}")
except KeyboardInterrupt:
elapsed = time.time() - start_time
manifest.status = "cancelled"
self._write_manifest(manifest, local_output_root)
print(f"\nCancelled after {self._fmt_duration(elapsed)} "
f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
except Exception:
elapsed = time.time() - start_time
manifest.status = "failed"
self._write_manifest(manifest, local_output_root)
print(f" Failed after {self._fmt_duration(elapsed)} "
f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
if opts.keep_on_failure or self.config.keep_on_failure:
should_destroy = False
raise
finally:
if instance_id is not None and should_destroy:
elapsed = time.time() - start_time
print(f"Destroying instance {instance_id} "
f"(total: {self._fmt_duration(elapsed)}, "
f"~{self._fmt_cost(elapsed, cost_per_hour)})")
self.api.destroy_instance(instance_id)
# CLI subcommand: search and display available GPU offers without renting
def offers(
self,
*,
sort_mode: str | None,
region: str | None,
price_cap: float | None,
select_offer: bool,
list_regions: bool,
limit_output: int,
) -> None:
offers = self._search_offers(sort_mode, region, price_cap)
if list_regions:
self._print_regions(offers)
return
if select_offer:
try:
offer = self._choose_offer(offers)
except (OfferSelectionAborted, KeyboardInterrupt):
print("Aborted.")
return
print(json.dumps(offer, indent=2))
else:
print(json.dumps(offers[:limit_output], indent=2))
# CLI subcommand: rent an instance and print its connection details (no training)
def up(self, *, label: str | None, opts: RunOptions | None = None) -> None:
if opts is None:
opts = RunOptions()
self.api.ensure_ssh_key(self.public_key)
template_hash_id = self._resolve_template(opts)
offer = self._search_offers()[0]
offer_id = int(offer["id"])
payload = self._build_instance_payload(
label or self._label_for("manual"), template_hash_id=template_hash_id
)
instance_id = self.api.create_instance(offer_id, payload)
self.api.attach_ssh_key(instance_id, self.public_key)
instance = self._wait_for_instance(instance_id)
print(json.dumps({
"offer_id": offer_id,
"instance_id": instance_id,
"ssh_host": instance.ssh_host,
"ssh_port": instance.ssh_port,
"gpu_name": instance.gpu_name,
"dph_total": instance.dph_total,
}, indent=2))
# CLI subcommand: print the raw instance status JSON
def status(self, instance_id: int) -> None:
print(json.dumps(self.api.show_instance(instance_id).raw, indent=2))
# CLI subcommand: destroy a running instance by ID
def down(self, instance_id: int) -> None:
self.api.destroy_instance(instance_id)
print(f"Destroyed instance {instance_id}")