Files
DRL_PROJ/pipeline/orchestrator.py
T
2026-04-30 03:33:42 +01:00

734 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import re
import threading
import time
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from pipeline.config import (
PipelineConfig,
read_public_key,
resolve_api_key,
resolve_ssh_private_key,
)
from pipeline.remote import (
RemoteCommandError,
SshTarget,
fetch_directory,
is_process_alive,
read_remote_file,
rsync_file,
rsync_project,
run_detached,
shell_join,
ssh,
ssh_output,
tail_log,
wait_for_ssh,
)
from pipeline.vast_api import VastApiClient, VastApiError, VastInstance
# ── data classes ─────────────────────────────────────────────────────
# Snapshot of a pipeline run written to disk for traceability
@dataclass(slots=True)
class RunManifest:
    """JSON-serializable record of one pipeline run, written by _write_manifest."""

    created_at: str               # ISO-8601 UTC timestamp; also keys the manifest filename
    config_paths: list[str]       # training configs, relative to the project root
    instance_id: int | None       # Vast.ai instance id, set once the instance is created
    offer_id: int | None          # Vast.ai offer the instance was rented from
    ssh_host: str | None          # SSH endpoint, known once the instance is running
    ssh_port: int | None          # SSH port paired with ssh_host
    status: str                   # "creating", the instance's actual status, "completed", "cancelled", or "failed"
    remote_workspace: str | None  # workspace directory on the remote host
# CLI flags controlling how the pipeline run behaves
@dataclass
class RunOptions:
    """Options parsed from the CLI that steer a pipeline run."""

    download_data: bool = False       # force the dataset download even if remote data exists
    send_cropped: bool = False        # upload a local cropped_<kind>.zip to the instance
    keep_on_failure: bool = False     # don't destroy the instance when the run fails
    dry_run: bool = False             # print the chosen offer(s) and stop before renting
    select_offer: bool = False        # interactive offer picker instead of the top-ranked one
    select_template: bool = False     # interactive Docker-template picker
    template_hash: str | None = None  # explicit Vast template hash (overrides picker/config)
    sort_mode: str | None = None      # offer ranking: "performance" / "price" / "dlp_per_dollar"
    region: str | None = None         # region filter, e.g. "europe"
    price_cap: float | None = None    # max $/h passed into the offer query
    list_regions: bool = False        # print region offer counts and exit
    use_gpu: bool = True              # pass --use-gpu to the training script
# ── constants ─────────────────────────────────────────────────────────
# ISO-3166-1-alpha-2 country codes used by the --region europe filter.
# NOTE(review): includes transcontinental TR/GE/AM/AZ and excludes RU —
# presumably a deliberate latency/jurisdiction choice; confirm with the owner.
EUROPE_REGION_CODES = {
    "AL", "AD", "AM", "AT", "AZ", "BA", "BE", "BG", "BY", "CH", "CY", "CZ",
    "DE", "DK", "EE", "ES", "FI", "FR", "GB", "GE", "GR", "HR", "HU", "IE",
    "IS", "IT", "LI", "LT", "LU", "LV", "MC", "MD", "ME", "MK", "MT", "NL",
    "NO", "PL", "PT", "RO", "RS", "SE", "SI", "SK", "SM", "TR", "UA", "VA",
}
# How often (in epochs) to rsync generator outputs back during training
_GENERATOR_SYNC_INTERVAL = 50
# Raised when the user aborts interactive offer selection
class OfferSelectionAborted(RuntimeError):
    """User entered 'q' at the offer prompt; callers treat this as a clean abort."""
    pass
# Raised when the user aborts interactive template selection
class TemplateSelectionAborted(RuntimeError):
    """User entered 'q' at the template prompt; the run falls back to the default image."""
    pass
# ── runner ────────────────────────────────────────────────────────────
# Orchestrates the full lifecycle: search offers -> rent GPU -> train -> fetch results -> destroy
class EphemeralVastRunner:
    """Rents an ephemeral Vast.ai GPU instance, runs training on it, and tears it down."""

    # Pre-defined Vast.ai image templates the user can pick interactively
    # (see _choose_template); the "none" entry opts out of templates entirely.
    TEMPLATE_CATALOG: list[dict[str, str]] = [
        {
            "hash_id": "661d064bbda1f2a133816b6d55da07c3",
            "name": "PyTorch (cuDNN Devel)",
            "image": "pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel",
        },
        {
            "hash_id": "b9e5a8f3d4c1e7f6a2b0d8c9e5f1a3b7",
            "name": "PyTorch (Latest)",
            "image": "pytorch/pytorch:latest",
        },
        {
            # Sentinel: keep the image from config and send no template hash
            "hash_id": "none",
            "name": "Custom image (no template)",
            "image": "",
        },
    ]
# Resolve the project root (two levels up from this file)
def __init__(self, override_path: Path | None) -> None:
    """Load config, build the Vast API client, and resolve SSH credentials.

    override_path: optional config override file forwarded to
    PipelineConfig.load; None uses the defaults.
    """
    # This file lives at <root>/pipeline/orchestrator.py
    self.project_root = Path(__file__).resolve().parent.parent
    self.config = PipelineConfig.load(self.project_root, override_path)
    self.api = VastApiClient(resolve_api_key())
    # Key pair: private key for SSH connections, public key registered with Vast.ai
    self.private_key = resolve_ssh_private_key()
    self.public_key = read_public_key(self.private_key)
    # rsync exclude list applied when uploading the project
    self.exclude_file = self.project_root / "pipeline" / "rsync-excludes.txt"
# ── formatting helpers ────────────────────────────────────────────
# Format seconds into a compact human-readable string (e.g. "1d2h30m")
@staticmethod
def _fmt_duration(seconds: float) -> str:
s = max(int(seconds), 0)
days, s = divmod(s, 86400)
hours, s = divmod(s, 3600)
minutes, _ = divmod(s, 60)
parts: list[str] = []
if days:
parts.append(f"{days}d")
if hours or days:
parts.append(f"{hours}h")
parts.append(f"{minutes}m")
return "".join(parts)
# Estimate cost from wall-clock time and the offer's $/h rate
@staticmethod
def _fmt_cost(elapsed_seconds: float, cost_per_hour: float) -> str:
return f"${(elapsed_seconds / 3600) * cost_per_hour:.2f}"
# ── offer search & selection ──────────────────────────────────────
# Build a sort tuple that ranks offers by the chosen mode (performance / price / dlp_per_dollar)
@staticmethod
def _offer_sort_key(offer: dict, sort_mode: str) -> tuple:
price = float(offer.get("dph_total", float("inf")))
dlperf = float(offer.get("dlperf", float("-inf")))
reliability = float(offer.get("reliability2", 0.0))
if sort_mode == "price":
return (price, -dlperf, -reliability)
if sort_mode == "dlp_per_dollar":
per_dollar = dlperf / price if price > 0 else 0.0
return (-per_dollar, -reliability, price)
return (-dlperf, price, -reliability)
# Match a user-friendly region string (e.g. "europe") against the offer's geolocation code
@staticmethod
def _region_matches(offered: str, requested: str) -> bool:
req = requested.strip().lower()
if not req:
return True
if req == "europe":
code = offered.rsplit(", ", 1)[-1].strip() if ", " in offered else ""
return code in EUROPE_REGION_CODES
return req in offered.lower()
# Query Vast.ai for matching offers, optionally filter by region, then sort
def _search_offers(
    self,
    sort_mode: str | None = None,
    region: str | None = None,
    price_cap: float | None = None,
) -> list[dict]:
    """Return offers matching the configured filters, sorted best-first.

    Raises VastApiError when nothing matches the base query or the
    region filter empties the result set.
    """
    query = self.config.build_offer_query(price_cap=price_cap)
    offers = self.api.search_offers(query)
    if not offers:
        raise VastApiError("No Vast offers matched the configured filters.")
    if region:
        offers = [
            o for o in offers
            if self._region_matches(o.get("geolocation") or "", region)
        ]
        if not offers:
            raise VastApiError(f"No offers matched region filter: {region!r}")
    # CLI sort mode wins; otherwise fall back to the configured default
    mode = sort_mode or self.config.search.get("sort_mode", "performance")
    return sorted(offers, key=lambda o: self._offer_sort_key(o, mode))
# Print a summary of available regions and their offer counts
@staticmethod
def _print_regions(offers: list[dict]) -> None:
    """Print "<region>: <count>" lines, most offers first, ties alphabetical."""
    tally: dict[str, int] = {}
    for offer in offers:
        name = offer.get("geolocation") or "Unknown"
        tally[name] = tally.get(name, 0) + 1
    ordered = sorted(tally.items(), key=lambda kv: (-kv[1], kv[0]))
    for name, count in ordered:
        print(f"{name}: {count}")
# Render one page of offers with GPU specs, price, and reliability
def _print_offer_page(self, offers: list[dict], *, page: int, page_size: int) -> None:
    """Print one page of *offers* as an indexed, human-readable listing.

    Indices are 1-based across the whole offer list so the number the user
    types in _choose_offer maps straight back into *offers*.
    """
    start = page * page_size
    end = min(start + page_size, len(offers))
    # BUG FIX: the range separator between the two numbers was missing
    # (likely a dropped en-dash), so "Offers 1-10 of 30" printed as "Offers 110 of 30".
    print(f"Offers {start + 1}-{end} of {len(offers)}")
    for idx in range(start, end):
        o = offers[idx]
        # gpu_ram appears to be reported in MiB; shown as GB — TODO confirm against API docs
        vram_gb = o.get("gpu_ram", 0) / 1024
        duration = self._fmt_duration(float(o.get("duration") or 0))
        print(
            f" [{idx + 1}] {o.get('gpu_name')} "
            f"gpus={o.get('num_gpus')} "
            f"vram={vram_gb:.0f}GB "
            f"dlperf={o.get('dlperf', 0):.1f} "
            f"$/h={o.get('dph_total', 0):.3f} "
            f"reliability={o.get('reliability2', 0):.3f} "
            f"avail={duration} "
            f"region={o.get('geolocation')} "
            f"id={o.get('id')}"
        )
# Paginated interactive prompt — user picks an offer or aborts
def _choose_offer(self, offers: list[dict]) -> dict:
    """Page through *offers* interactively and return the selected one.

    'n'/'p' move between pages, a number selects (1-based, across all
    pages), and 'q' raises OfferSelectionAborted.
    """
    page_size = 10
    page = 0
    last_page = max((len(offers) - 1) // page_size, 0)
    while True:
        self._print_offer_page(offers, page=page, page_size=page_size)
        answer = input("Select offer [number / n / p / q]: ").strip().lower()
        if answer == "q":
            raise OfferSelectionAborted
        if answer == "n":
            page = min(page + 1, last_page)
        elif answer == "p":
            page = max(page - 1, 0)
        elif answer.isdigit():
            chosen = int(answer) - 1
            if 0 <= chosen < len(offers):
                return offers[chosen]
            print(" Out of range.")
        else:
            print(" Invalid input.")
# ── template selection ────────────────────────────────────────────
# Decide which Vast template hash to use: CLI flag, interactive pick, or config default
def _resolve_template(self, opts: RunOptions) -> str | None:
    """Resolve the template hash: explicit CLI flag > interactive pick > config.

    An aborted interactive pick falls back to None (default image).
    """
    if opts.template_hash:
        return opts.template_hash
    if not opts.select_template:
        return self.config.instance.get("template_hash_id")
    try:
        return self._choose_template()
    except (TemplateSelectionAborted, KeyboardInterrupt):
        print("Template selection aborted, using default image.")
        return None
# Interactive prompt — user picks a Docker image template
def _choose_template(self) -> str | None:
    """Let the user pick a template from TEMPLATE_CATALOG.

    Side effect: a template with a concrete image overwrites
    self.config.instance["image"]. The "none" sentinel returns None;
    'q' raises TemplateSelectionAborted.
    """
    print("Available templates:")
    for number, tpl in enumerate(self.TEMPLATE_CATALOG, 1):
        print(f" [{number}] {tpl['name']} ({tpl['image'] or 'config image'})")
    while True:
        answer = input("Select template [number / q]: ").strip().lower()
        if answer == "q":
            raise TemplateSelectionAborted
        if not answer.isdigit():
            print(" Invalid input.")
            continue
        pos = int(answer) - 1
        if not (0 <= pos < len(self.TEMPLATE_CATALOG)):
            print(" Invalid input.")
            continue
        chosen = self.TEMPLATE_CATALOG[pos]
        if chosen["hash_id"] == "none":
            return None
        if chosen["image"]:
            self.config.instance["image"] = chosen["image"]
        print(f" Selected: {chosen['name']}")
        return chosen["hash_id"]
# ── instance lifecycle ────────────────────────────────────────────
def _label_for(self, stem: str) -> str:
    """Return '<label_prefix>-<stem>-<UTC timestamp>' as a unique instance label."""
    stamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
    prefix = self.config.instance["label_prefix"]
    return f"{prefix}-{stem}-{stamp}"
# Assemble the JSON payload sent to Vast.ai when renting an instance
def _build_instance_payload(self, label: str, *, template_hash_id: str | None = None) -> dict[str, Any]:
inst = self.config.instance
payload: dict[str, Any] = {
"image": inst["image"],
"disk": inst["disk_gb"],
"runtype": "ssh_direc ssh_proxy",
"target_state": inst["target_state"],
"label": label,
}
if template_hash_id:
payload["template_hash_id"] = template_hash_id
return payload
# Poll the API until the instance is running and has SSH credentials
def _wait_for_instance(self, instance_id: int) -> VastInstance:
    """Block until *instance_id* is running with SSH details available.

    Polls every remote.poll_interval_seconds and raises VastApiError
    after remote.ssh_timeout_seconds without success.
    """
    timeout_at = time.time() + self.config.remote["ssh_timeout_seconds"]
    interval = self.config.remote["poll_interval_seconds"]
    while time.time() < timeout_at:
        info = self.api.show_instance(instance_id)
        ready = (
            info.actual_status == "running"
            and info.ssh_host
            and info.ssh_port
        )
        if ready:
            return info
        print(f" instance status: {info.actual_status or 'pending'}...")
        time.sleep(interval)
    raise VastApiError(f"Timed out waiting for instance {instance_id} to be SSH-ready.")
# Build an SshTarget from the running instance's connection details
def _build_target(self, instance: VastInstance) -> SshTarget:
    """Translate an instance's connection info into an SshTarget."""
    # Prefer the dedicated SSH host; fall back to the public IP, then ""
    host = instance.ssh_host or instance.public_ipaddr or ""
    port = int(instance.ssh_port or 22)  # default to 22 when the API omits it
    return SshTarget(
        user=self.config.remote["ssh_user"],
        host=host,
        port=port,
        private_key=self.private_key,
    )
# ── manifest ──────────────────────────────────────────────────────
def _write_manifest(self, manifest: RunManifest, local_output_root: Path) -> Path:
    """Write *manifest* to <output>/pipeline/<timestamp>.json and return the path.

    The filename derives from manifest.created_at, which never changes,
    so re-writing during a run overwrites the same file.
    """
    out_dir = local_output_root / "pipeline"
    out_dir.mkdir(parents=True, exist_ok=True)
    # Strip ':' and '-' so the ISO timestamp is filesystem-friendly
    stamp = manifest.created_at.replace(":", "").replace("-", "")
    out_path = out_dir / f"{stamp}.json"
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(asdict(manifest), fh, indent=2)
    return out_path
# ── remote workspace setup ────────────────────────────────────────
# Upload project files, bootstrap the venv, and ensure training data is present
def _prepare_remote_workspace(
    self, target: SshTarget, *, download_data: bool, send_cropped: bool,
    first_config: Path | None = None,
) -> None:
    """Sync the project to *target*, bootstrap the env, and stage training data.

    download_data: force a dataset download even when remote data exists.
    send_cropped: upload the local cropped_<kind>.zip matching first_config.
    """
    ws = self.config.remote["workspace_dir"]
    poll = self.config.remote["poll_interval_seconds"]
    print(" Syncing project files...")
    ssh(target, shell_join(["mkdir", "-p", ws]))
    rsync_project(self.project_root, target, ws, self.exclude_file)
    print(" Bootstrapping remote environment...")
    ssh(target, shell_join(["bash", "-lc", f"cd {ws} && bash pipeline/scripts/bootstrap_env.sh"]))
    if send_cropped and first_config is not None:
        # Determine which cropped subdirectory to send based on the first config
        first_mod, _ = self._module_for_config(first_config)
        crop_subdir = "generator" if first_mod.startswith("generator/") else "classifier"
        local_zip = self.project_root / f"cropped_{crop_subdir}.zip"
        if not local_zip.exists():
            # Missing archive is only a warning — the flag is best-effort
            print(f" Warning: --send-cropped set but cropped_{crop_subdir}.zip not found.")
            print(f" Create it with: zip -r cropped_{crop_subdir}.zip cropped/{crop_subdir}/")
        else:
            size_mb = local_zip.stat().st_size / (1024 * 1024)
            print(f" Sending cropped_{crop_subdir}.zip ({size_mb:.0f} MB)...")
            rsync_file(local_zip, target, ws)
            print(" Unzipping on remote...")
            ssh(target, shell_join(["bash", "-lc", f"cd {ws} && unzip -q -o cropped_{crop_subdir}.zip"]))
            print(f" Pre-cropped {crop_subdir} images ready.")
    data_ready = self._remote_data_exists(target, f"{ws}/data")
    if download_data or not data_ready:
        # Download the dataset in a detached process so the SSH timeout doesn't kill it
        print(" Downloading dataset from HuggingFace...")
        log = f"{ws}/.pipeline_logs/fetch_ds.log"
        ssh(target, f"mkdir -p {ws}/.pipeline_logs")
        pid = run_detached(
            target,
            f"cd {ws} && source .venv/bin/activate && python3 classifier/tools/fetch_ds.py",
            log,
        )
        reconnects = 0
        # Poll until the fetch process exits, with reconnect logic for SSH drops
        while True:
            try:
                if not is_process_alive(target, pid):
                    break
                reconnects = 0
            except RemoteCommandError as exc:
                # Exit 255 is treated as an SSH transport drop: retry with a
                # growing sleep (capped at 30s); anything else is fatal
                if "Command failed (255)" not in str(exc) or reconnects >= 10:
                    raise
                reconnects += 1
                time.sleep(min(5 * reconnects, 30))
                continue
            time.sleep(poll)
        print(" Dataset ready.")
# ── config routing ────────────────────────────────────────────────
# Determine which training script and output dir to use based on config location
def _module_for_config(self, config_path: Path) -> tuple[str, str]:
try:
rel = config_path.resolve().relative_to(self.project_root)
except ValueError:
rel = config_path
if rel.parts[0] == "generator":
return "generator/run.py", "generator/outputs"
return "classifier/run.py", "classifier/outputs"
# ── remote directory checks ───────────────────────────────────────
def _remote_dir_exists(self, target: SshTarget, path: str) -> bool:
    """True when *path* exists as a directory on the remote host."""
    probe = f"if [ -d {path} ]; then echo yes; else echo no; fi"
    return ssh_output(target, probe).strip() == "yes"
# Check that the dataset is populated (not just an empty directory)
def _remote_data_exists(self, target: SshTarget, data_dir: str) -> bool:
    """True when the dataset is populated (i.e. <data_dir>/wiki exists remotely)."""
    probe = f"if [ -d {data_dir}/wiki ]; then echo yes; else echo no; fi"
    return ssh_output(target, probe).strip() == "yes"
# ── training ─────────────────────────────────────────────────────
# Launch the training script inside nohup, stream logs, and handle reconnects
def _run_training(self, target: SshTarget, config_path: Path, opts: RunOptions) -> None:
    """Run one training config on *target*, streaming its log locally.

    The remote process is detached (nohup) so SSH drops cannot kill it;
    its exit code is read from a marker file afterwards. Raises
    RemoteCommandError with a log-tail snippet when training exits non-zero.
    """
    ws = self.config.remote["workspace_dir"]
    config_rel = config_path.resolve().relative_to(self.project_root)
    run_script, output_root = self._module_for_config(config_path)
    is_generator = run_script.startswith("generator/")
    run_cmd = shell_join(["python3", "-u", run_script, str(config_rel), "--output-root", output_root])
    if opts.use_gpu:
        run_cmd += " --use-gpu"
    log_dir = f"{ws}/.pipeline_logs"
    log_path = f"{log_dir}/{config_path.stem}.log"
    exit_path = f"{log_dir}/{config_path.stem}.exit"
    ssh(target, f"mkdir -p {log_dir}")
    # Run training under nohup; write the exit code to a marker file when done
    pid = run_detached(
        target,
        f"cd {ws} && source .venv/bin/activate && {run_cmd}; echo $? > {exit_path}",
        log_path,
    )
    print(f" Training started (PID {pid})")
    poll = self.config.remote["poll_interval_seconds"]
    max_reconnects = 10
    reconnects = 0
    # Generators get periodic output syncs; None disables them entirely
    next_sync = _GENERATOR_SYNC_INTERVAL if is_generator else None
    local_output_root = self.project_root / output_root
    stop_streaming = threading.Event()
    # Background thread: tail the remote log and print tqdm-aware output
    def _stream_worker() -> None:
        tail_proc = None
        last_was_progress = False
        try:
            time.sleep(1)  # wait for the log file to be created
            tail_proc = tail_log(target, log_path)
            for raw in tail_proc.stdout:
                if stop_streaming.is_set():
                    break
                # tqdm uses \r for in-place updates; take only the last segment
                line = raw.rstrip("\r\n").rsplit("\r", 1)[-1].rstrip()
                if not line:
                    continue
                is_progress = "it/s]" in line or "%|" in line
                if is_progress:
                    # Overwrite the current console line, tqdm-style
                    print(f"\r {line}", end="", flush=True)
                    last_was_progress = True
                else:
                    if last_was_progress:
                        # Finish the in-place progress line before normal output
                        print(flush=True)
                        last_was_progress = False
                    print(f" {line}", flush=True)
        except Exception:
            # Streaming is best-effort; the polling loop below is authoritative
            pass
        finally:
            if last_was_progress:
                print(flush=True)
            if tail_proc is not None:
                tail_proc.terminate()
                tail_proc.wait(timeout=5)
    stream_thread = threading.Thread(target=_stream_worker, daemon=True)
    stream_thread.start()
    try:
        # Main polling loop: check if the process is alive, periodically sync generator outputs
        while True:
            try:
                alive = is_process_alive(target, pid)
            except RemoteCommandError as exc:
                # Exit 255 is treated as an SSH transport drop: back off
                # linearly (capped at 30s); anything else is fatal
                if "Command failed (255)" not in str(exc) or reconnects >= max_reconnects:
                    raise
                reconnects += 1
                print(f"\n SSH dropped (reconnect {reconnects}/{max_reconnects}), "
                      "training continues on remote...")
                time.sleep(min(5 * reconnects, 30))
                continue
            if is_generator and next_sync is not None:
                epoch = self._latest_epoch_in_log(target, log_path)
                if epoch >= next_sync:
                    print(f" Syncing generator outputs at epoch {next_sync}...")
                    self._fetch_outputs(target, output_root, local_output_root)
                    next_sync += _GENERATOR_SYNC_INTERVAL
            if not alive:
                break
            reconnects = 0  # a successful liveness check resets the drop counter
            time.sleep(poll)
    finally:
        stop_streaming.set()
        stream_thread.join(timeout=10)
    # Read the exit code marker that the nohup wrapper wrote
    raw_exit = read_remote_file(target, exit_path)
    # NOTE(review): int() raises if the marker exists but is empty — assumes the
    # shell finished writing it before the process was observed dead; verify.
    exit_code = int(raw_exit) if raw_exit is not None else -1
    if exit_code != 0:
        # Salvage whatever outputs exist, then surface the log tail in the error
        if self._remote_dir_exists(target, f"{ws}/{output_root}"):
            self._fetch_outputs(target, output_root, local_output_root)
        log_tail = read_remote_file(target, log_path)
        snippet = "\n".join((log_tail or "").splitlines()[-30:])
        raise RemoteCommandError(f"Training failed (exit {exit_code}).\n{snippet}")
# Parse the latest epoch number from the training log (used for generator sync scheduling)
@staticmethod
def _latest_epoch_in_log(target: SshTarget, log_path: str) -> int:
    """Best-effort: the highest '[epoch/total]' seen in the log tail, else 0."""
    try:
        tail = ssh_output(target, f"tail -n 200 {log_path} 2>/dev/null || true")
    except RemoteCommandError:
        # Can't read the log right now; treat as no progress
        return 0
    found = re.findall(r"\[(\d+)/\d+\]", tail)
    return max((int(m) for m in found), default=0)
# Download remote outputs with up to 3 retry attempts (exponential back-off)
def _fetch_outputs(self, target: SshTarget, remote_output_root: str, local_output_root: Path) -> None:
    """Fetch remote outputs locally, retrying twice with exponential back-off.

    Re-raises the RemoteCommandError from the final failed attempt.
    """
    remote_path = f"{self.config.remote['workspace_dir']}/{remote_output_root}"
    wait = 5
    attempts = 3
    for attempt in range(1, attempts + 1):
        try:
            fetch_directory(target, remote_path, local_output_root)
            return
        except RemoteCommandError as exc:
            if attempt == attempts:
                raise
            print(f" Download attempt {attempt} failed: {exc}")
            print(f" Retrying in {wait}s...")
            time.sleep(wait)
            wait *= 2
# ── public commands ───────────────────────────────────────────────
# Full pipeline: validate configs -> find offer -> rent instance -> train -> fetch -> destroy
def run(self, config_paths: list[Path], opts: RunOptions) -> None:
    """Execute the full rent-train-fetch-destroy pipeline for *config_paths*.

    All configs run sequentially on one instance. The manifest is re-written
    at every status transition; the instance is always destroyed in the
    finally block unless keep-on-failure applies. Raises FileNotFoundError
    for missing configs and re-raises training/API failures.
    """
    # Resolve configs against the project root and fail fast on missing files
    resolved = []
    for cp in config_paths:
        # NOTE(review): absolute paths are not resolve()d here, so the
        # relative_to() below could raise for a symlinked absolute path — verify.
        cp = (self.project_root / cp).resolve() if not cp.is_absolute() else cp
        if not cp.exists():
            raise FileNotFoundError(f"Config not found: {cp}")
        resolved.append(cp)
    n = len(resolved)
    # The manifest lives under the first config's output root
    _, first_output_root = self._module_for_config(resolved[0])
    local_output_root = self.project_root / first_output_root
    self.api.ensure_ssh_key(self.public_key)
    offers = self._search_offers(opts.sort_mode, opts.region, opts.price_cap)
    if opts.list_regions:
        self._print_regions(offers)
        return
    try:
        offer = self._choose_offer(offers) if opts.select_offer else offers[0]
    except (OfferSelectionAborted, KeyboardInterrupt):
        print("Aborted.")
        return
    if opts.dry_run:
        # Show what would be rented without creating anything
        print(json.dumps(offer if opts.select_offer else offers[:10], indent=2))
        return
    # Single config -> label by file stem; batch -> label by parent directory name
    label_stem = resolved[0].stem if n == 1 else resolved[0].parent.name
    template_hash_id = self._resolve_template(opts)
    offer_id = int(offer["id"])
    cost_per_hour = float(offer.get("dph_total", 0))
    manifest = RunManifest(
        created_at=datetime.now(UTC).isoformat(),
        config_paths=[str(cp.relative_to(self.project_root)) for cp in resolved],
        instance_id=None,
        offer_id=offer_id,
        ssh_host=None,
        ssh_port=None,
        status="creating",
        remote_workspace=self.config.remote["workspace_dir"],
    )
    manifest_path = self._write_manifest(manifest, local_output_root)
    instance_id: int | None = None
    should_destroy = True
    start_time = time.time()
    try:
        # Build the instance payload and rent the GPU
        payload = self._build_instance_payload(self._label_for(label_stem),
                                               template_hash_id=template_hash_id)
        print(f"Creating instance from offer {offer_id} (${cost_per_hour:.3f}/h)...")
        instance_id = self.api.create_instance(offer_id, payload)
        self.api.attach_ssh_key(instance_id, self.public_key)
        # Record the instance ID in the manifest immediately for debugging
        manifest.instance_id = instance_id
        # Wait for the instance to be SSH-ready, then update the manifest
        print(f"Waiting for instance {instance_id}...")
        instance = self._wait_for_instance(instance_id)
        manifest.ssh_host = instance.ssh_host
        manifest.ssh_port = instance.ssh_port
        manifest.status = instance.actual_status
        self._write_manifest(manifest, local_output_root)
        print(f"Instance ready: {instance.ssh_host}:{instance.ssh_port} "
              f"({instance.gpu_name}, ${instance.dph_total}/h)")
        target = self._build_target(instance)
        # Wait for SSH to become available before running any remote commands
        print("Waiting for SSH...")
        wait_for_ssh(
            target,
            timeout_seconds=self.config.remote["ssh_timeout_seconds"],
            poll_interval_seconds=self.config.remote["poll_interval_seconds"],
        )
        self._prepare_remote_workspace(
            target, download_data=opts.download_data, send_cropped=opts.send_cropped,
            first_config=resolved[0],
        )
        # Run each training config sequentially on the same instance
        for i, config_path in enumerate(resolved, 1):
            print(f"\n[{i}/{n}] Training: {config_path.name}")
            self._run_training(target, config_path, opts)
            print(f"[{i}/{n}] Fetching outputs...")
            _, output_root = self._module_for_config(config_path)
            self._fetch_outputs(target, output_root, self.project_root / output_root)
        elapsed = time.time() - start_time
        manifest.status = "completed"
        self._write_manifest(manifest, local_output_root)
        print(f"\nAll {n} run(s) completed in {self._fmt_duration(elapsed)} "
              f"(~{self._fmt_cost(elapsed, cost_per_hour)}). Manifest: {manifest_path}")
    except KeyboardInterrupt:
        # Ctrl-C is a clean cancel: record it and fall through to teardown
        elapsed = time.time() - start_time
        manifest.status = "cancelled"
        self._write_manifest(manifest, local_output_root)
        print(f"\nCancelled after {self._fmt_duration(elapsed)} "
              f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
    except Exception:
        elapsed = time.time() - start_time
        manifest.status = "failed"
        self._write_manifest(manifest, local_output_root)
        print(f" Failed after {self._fmt_duration(elapsed)} "
              f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
        # Keep the instance alive for post-mortem debugging when requested
        if opts.keep_on_failure or self.config.keep_on_failure:
            should_destroy = False
        raise
    finally:
        if instance_id is not None and should_destroy:
            elapsed = time.time() - start_time
            print(f"Destroying instance {instance_id} "
                  f"(total: {self._fmt_duration(elapsed)}, "
                  f"~{self._fmt_cost(elapsed, cost_per_hour)})")
            self.api.destroy_instance(instance_id)
# CLI subcommand: search and display available GPU offers without renting
def offers(
    self,
    *,
    sort_mode: str | None,
    region: str | None,
    price_cap: float | None,
    select_offer: bool,
    list_regions: bool,
    limit_output: int,
) -> None:
    """Search offers and print them (or region counts) without renting anything."""
    found = self._search_offers(sort_mode, region, price_cap)
    if list_regions:
        self._print_regions(found)
        return
    if not select_offer:
        # Non-interactive: dump the top N offers as JSON
        print(json.dumps(found[:limit_output], indent=2))
        return
    try:
        chosen = self._choose_offer(found)
    except (OfferSelectionAborted, KeyboardInterrupt):
        print("Aborted.")
        return
    print(json.dumps(chosen, indent=2))
# CLI subcommand: rent an instance and print its connection details (no training)
def up(self, *, label: str | None, opts: RunOptions | None = None) -> None:
    """Rent the top-ranked offer and print its connection details as JSON.

    label: optional instance label; defaults to a generated "manual" label.
    """
    options = opts if opts is not None else RunOptions()
    self.api.ensure_ssh_key(self.public_key)
    template_hash_id = self._resolve_template(options)
    # No interactive selection here: take the best-ranked offer
    best = self._search_offers()[0]
    offer_id = int(best["id"])
    payload = self._build_instance_payload(
        label or self._label_for("manual"), template_hash_id=template_hash_id
    )
    instance_id = self.api.create_instance(offer_id, payload)
    self.api.attach_ssh_key(instance_id, self.public_key)
    instance = self._wait_for_instance(instance_id)
    details = {
        "offer_id": offer_id,
        "instance_id": instance_id,
        "ssh_host": instance.ssh_host,
        "ssh_port": instance.ssh_port,
        "gpu_name": instance.gpu_name,
        "dph_total": instance.dph_total,
    }
    print(json.dumps(details, indent=2))
# CLI subcommand: print the raw instance status JSON
def status(self, instance_id: int) -> None:
    """Fetch and pretty-print the raw Vast.ai record for *instance_id*."""
    raw = self.api.show_instance(instance_id).raw
    print(json.dumps(raw, indent=2))
# CLI subcommand: destroy a running instance by ID
def down(self, instance_id: int) -> None:
    """Tear down *instance_id* and confirm on stdout."""
    self.api.destroy_instance(instance_id)
    print(f"Destroyed instance {instance_id}")