"""Ephemeral Vast.ai runner: rent a GPU, train, fetch results, destroy.

Orchestrates the full lifecycle of a training run on a transient Vast.ai
instance: search offers -> rent -> upload project -> train -> download
outputs -> destroy the instance.
"""

import json
import re
import threading
import time
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

from pipeline.config import (
    PipelineConfig,
    read_public_key,
    resolve_api_key,
    resolve_ssh_private_key,
)
from pipeline.remote import (
    RemoteCommandError,
    SshTarget,
    fetch_directory,
    is_process_alive,
    read_remote_file,
    rsync_file,
    rsync_project,
    run_detached,
    shell_join,
    ssh,
    ssh_output,
    tail_log,
    wait_for_ssh,
)
from pipeline.vast_api import VastApiClient, VastApiError, VastInstance

# ── data classes ─────────────────────────────────────────────────────


@dataclass(slots=True)
class RunManifest:
    """Snapshot of a pipeline run written to disk for traceability."""

    created_at: str
    config_paths: list[str]
    instance_id: int | None
    offer_id: int | None
    ssh_host: str | None
    ssh_port: int | None
    status: str
    remote_workspace: str | None


@dataclass
class RunOptions:
    """CLI flags controlling how the pipeline run behaves."""

    download_data: bool = False
    send_cropped: bool = False
    keep_on_failure: bool = False
    dry_run: bool = False
    select_offer: bool = False
    select_template: bool = False
    template_hash: str | None = None
    sort_mode: str | None = None
    region: str | None = None
    price_cap: float | None = None
    list_regions: bool = False
    use_gpu: bool = True


# ── constants ────────────────────────────────────────────────────────

# ISO-3166-1-alpha-2 country codes used by the --region europe filter
EUROPE_REGION_CODES = {
    "AL", "AD", "AM", "AT", "AZ", "BA", "BE", "BG", "BY", "CH", "CY",
    "CZ", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "GE", "GR", "HR",
    "HU", "IE", "IS", "IT", "LI", "LT", "LU", "LV", "MC", "MD", "ME",
    "MK", "MT", "NL", "NO", "PL", "PT", "RO", "RS", "SE", "SI", "SK",
    "SM", "TR", "UA", "VA",
}

# How often (in epochs) to rsync generator outputs back during training
_GENERATOR_SYNC_INTERVAL = 50


class OfferSelectionAborted(RuntimeError):
    """Raised when the user aborts interactive offer selection."""


class TemplateSelectionAborted(RuntimeError):
    """Raised when the user aborts interactive template selection."""


# ── runner ───────────────────────────────────────────────────────────


class EphemeralVastRunner:
    """Orchestrates: search offers -> rent GPU -> train -> fetch results -> destroy."""

    # Pre-defined Vast.ai image templates the user can pick interactively
    TEMPLATE_CATALOG: list[dict[str, str]] = [
        {
            "hash_id": "661d064bbda1f2a133816b6d55da07c3",
            "name": "PyTorch (cuDNN Devel)",
            "image": "pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel",
        },
        {
            "hash_id": "b9e5a8f3d4c1e7f6a2b0d8c9e5f1a3b7",
            "name": "PyTorch (Latest)",
            "image": "pytorch/pytorch:latest",
        },
        {
            "hash_id": "none",
            "name": "Custom image (no template)",
            "image": "",
        },
    ]

    def __init__(self, override_path: Path | None) -> None:
        """Resolve the project root (two levels up from this file) and load config/keys."""
        self.project_root = Path(__file__).resolve().parent.parent
        self.config = PipelineConfig.load(self.project_root, override_path)
        self.api = VastApiClient(resolve_api_key())
        self.private_key = resolve_ssh_private_key()
        self.public_key = read_public_key(self.private_key)
        self.exclude_file = self.project_root / "pipeline" / "rsync-excludes.txt"

    # ── formatting helpers ───────────────────────────────────────────

    @staticmethod
    def _fmt_duration(seconds: float) -> str:
        """Format seconds into a compact human-readable string (e.g. "1d2h30m")."""
        s = max(int(seconds), 0)
        days, s = divmod(s, 86400)
        hours, s = divmod(s, 3600)
        minutes, _ = divmod(s, 60)
        parts: list[str] = []
        if days:
            parts.append(f"{days}d")
        # Include hours whenever days are shown so "1d0h5m" stays unambiguous
        if hours or days:
            parts.append(f"{hours}h")
        parts.append(f"{minutes}m")
        return "".join(parts)

    @staticmethod
    def _fmt_cost(elapsed_seconds: float, cost_per_hour: float) -> str:
        """Estimate cost from wall-clock time and the offer's $/h rate."""
        return f"${(elapsed_seconds / 3600) * cost_per_hour:.2f}"

    # ── offer search & selection ─────────────────────────────────────

    @staticmethod
    def _offer_sort_key(offer: dict, sort_mode: str) -> tuple:
        """Build a sort tuple ranking offers by mode (performance / price / dlp_per_dollar).

        Missing fields sort last: price defaults to +inf, dlperf to -inf.
        """
        price = float(offer.get("dph_total", float("inf")))
        dlperf = float(offer.get("dlperf", float("-inf")))
        reliability = float(offer.get("reliability2", 0.0))
        if sort_mode == "price":
            return (price, -dlperf, -reliability)
        if sort_mode == "dlp_per_dollar":
            per_dollar = dlperf / price if price > 0 else 0.0
            return (-per_dollar, -reliability, price)
        # Default mode: "performance"
        return (-dlperf, price, -reliability)

    @staticmethod
    def _region_matches(offered: str, requested: str) -> bool:
        """Match a user-friendly region string (e.g. "europe") against the offer's geolocation.

        The special value "europe" compares the trailing country code (the part
        after the last ", ") against EUROPE_REGION_CODES; anything else is a
        case-insensitive substring match.
        """
        req = requested.strip().lower()
        if not req:
            return True
        if req == "europe":
            code = offered.rsplit(", ", 1)[-1].strip() if ", " in offered else ""
            return code in EUROPE_REGION_CODES
        return req in offered.lower()

    def _search_offers(
        self,
        sort_mode: str | None = None,
        region: str | None = None,
        price_cap: float | None = None,
    ) -> list[dict]:
        """Query Vast.ai for matching offers, optionally filter by region, then sort.

        Raises VastApiError when no offers match the filters or the region.
        """
        query = self.config.build_offer_query(price_cap=price_cap)
        offers = self.api.search_offers(query)
        if not offers:
            raise VastApiError("No Vast offers matched the configured filters.")
        if region:
            offers = [o for o in offers if self._region_matches(o.get("geolocation") or "", region)]
            if not offers:
                raise VastApiError(f"No offers matched region filter: {region!r}")
        mode = sort_mode or self.config.search.get("sort_mode", "performance")
        return sorted(offers, key=lambda o: self._offer_sort_key(o, mode))

    @staticmethod
    def _print_regions(offers: list[dict]) -> None:
        """Print a summary of available regions and their offer counts (most offers first)."""
        counts: dict[str, int] = {}
        for o in offers:
            region = o.get("geolocation") or "Unknown"
            counts[region] = counts.get(region, 0) + 1
        for region, count in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
            print(f"{region}: {count}")

    def _print_offer_page(self, offers: list[dict], *, page: int, page_size: int) -> None:
        """Render one page of offers with GPU specs, price, and reliability."""
        start = page * page_size
        end = min(start + page_size, len(offers))
        print(f"Offers {start + 1}–{end} of {len(offers)}")
        for idx in range(start, end):
            o = offers[idx]
            # gpu_ram is reported in MB; show GB
            vram_gb = o.get("gpu_ram", 0) / 1024
            duration = self._fmt_duration(float(o.get("duration") or 0))
            print(
                f" [{idx + 1}] {o.get('gpu_name')} "
                f"gpus={o.get('num_gpus')} "
                f"vram={vram_gb:.0f}GB "
                f"dlperf={o.get('dlperf', 0):.1f} "
                f"$/h={o.get('dph_total', 0):.3f} "
                f"reliability={o.get('reliability2', 0):.3f} "
                f"avail={duration} "
                f"region={o.get('geolocation')} "
                f"id={o.get('id')}"
            )

    def _choose_offer(self, offers: list[dict]) -> dict:
        """Paginated interactive prompt — user picks an offer or aborts.

        Raises OfferSelectionAborted on "q".
        """
        page, page_size = 0, 10
        last_page = max((len(offers) - 1) // page_size, 0)
        while True:
            self._print_offer_page(offers, page=page, page_size=page_size)
            raw = input("Select offer [number / n / p / q]: ").strip().lower()
            if raw == "n":
                page = min(page + 1, last_page)
            elif raw == "p":
                page = max(page - 1, 0)
            elif raw == "q":
                raise OfferSelectionAborted
            elif raw.isdigit():
                idx = int(raw) - 1
                if 0 <= idx < len(offers):
                    return offers[idx]
                print(" Out of range.")
            else:
                print(" Invalid input.")

    # ── template selection ───────────────────────────────────────────

    def _resolve_template(self, opts: RunOptions) -> str | None:
        """Decide which Vast template hash to use: CLI flag, interactive pick, or config default."""
        if opts.template_hash:
            return opts.template_hash
        if opts.select_template:
            try:
                return self._choose_template()
            except (TemplateSelectionAborted, KeyboardInterrupt):
                print("Template selection aborted, using default image.")
                return None
        return self.config.instance.get("template_hash_id")

    def _choose_template(self) -> str | None:
        """Interactive prompt — user picks a Docker image template.

        Selecting a catalog entry with a concrete image also overrides
        config.instance["image"]. Raises TemplateSelectionAborted on "q".
        """
        print("Available templates:")
        for idx, tpl in enumerate(self.TEMPLATE_CATALOG, 1):
            print(f" [{idx}] {tpl['name']} ({tpl['image'] or 'config image'})")
        while True:
            raw = input("Select template [number / q]: ").strip().lower()
            if raw == "q":
                raise TemplateSelectionAborted
            if raw.isdigit():
                idx = int(raw) - 1
                if 0 <= idx < len(self.TEMPLATE_CATALOG):
                    tpl = self.TEMPLATE_CATALOG[idx]
                    if tpl["hash_id"] == "none":
                        return None
                    if tpl["image"]:
                        self.config.instance["image"] = tpl["image"]
                    print(f" Selected: {tpl['name']}")
                    return tpl["hash_id"]
            print(" Invalid input.")

    # ── instance lifecycle ───────────────────────────────────────────

    def _label_for(self, stem: str) -> str:
        """Build a unique instance label from a prefix, stem, and UTC timestamp."""
        ts = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
        return f"{self.config.instance['label_prefix']}-{stem}-{ts}"

    def _build_instance_payload(self, label: str, *, template_hash_id: str | None = None) -> dict[str, Any]:
        """Assemble the JSON payload sent to Vast.ai when renting an instance."""
        inst = self.config.instance
        payload: dict[str, Any] = {
            "image": inst["image"],
            "disk": inst["disk_gb"],
            # Request both direct and proxied SSH access.
            # BUG FIX: was "ssh_direc ssh_proxy" (misspelled token).
            "runtype": "ssh_direct ssh_proxy",
            "target_state": inst["target_state"],
            "label": label,
        }
        if template_hash_id:
            payload["template_hash_id"] = template_hash_id
        return payload

    def _wait_for_instance(self, instance_id: int) -> VastInstance:
        """Poll the API until the instance is running and has SSH credentials.

        Raises VastApiError after remote.ssh_timeout_seconds.
        """
        deadline = time.time() + self.config.remote["ssh_timeout_seconds"]
        poll = self.config.remote["poll_interval_seconds"]
        while time.time() < deadline:
            instance = self.api.show_instance(instance_id)
            if instance.actual_status == "running" and instance.ssh_host and instance.ssh_port:
                return instance
            print(f" instance status: {instance.actual_status or 'pending'}...")
            time.sleep(poll)
        raise VastApiError(f"Timed out waiting for instance {instance_id} to be SSH-ready.")

    def _build_target(self, instance: VastInstance) -> SshTarget:
        """Build an SshTarget from the running instance's connection details."""
        return SshTarget(
            user=self.config.remote["ssh_user"],
            host=instance.ssh_host or instance.public_ipaddr or "",
            port=int(instance.ssh_port or 22),
            private_key=self.private_key,
        )

    # ── manifest ─────────────────────────────────────────────────────

    def _write_manifest(self, manifest: RunManifest, local_output_root: Path) -> Path:
        """Persist the run manifest as a timestamped JSON file under the output directory.

        The filename is derived from created_at, so repeated writes for the
        same run overwrite the same file (intentional: the manifest is updated
        in place as the run progresses).
        """
        d = local_output_root / "pipeline"
        d.mkdir(parents=True, exist_ok=True)
        ts = manifest.created_at.replace(":", "").replace("-", "")
        path = d / f"{ts}.json"
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(asdict(manifest), fh, indent=2)
        return path

    # ── remote workspace setup ───────────────────────────────────────

    def _prepare_remote_workspace(
        self,
        target: SshTarget,
        *,
        download_data: bool,
        send_cropped: bool,
        first_config: Path | None = None,
    ) -> None:
        """Upload project files, bootstrap the venv, and ensure training data is present."""
        ws = self.config.remote["workspace_dir"]
        poll = self.config.remote["poll_interval_seconds"]
        print(" Syncing project files...")
        ssh(target, shell_join(["mkdir", "-p", ws]))
        rsync_project(self.project_root, target, ws, self.exclude_file)
        print(" Bootstrapping remote environment...")
        ssh(target, shell_join(["bash", "-lc", f"cd {ws} && bash pipeline/scripts/bootstrap_env.sh"]))
        if send_cropped and first_config is not None:
            # Determine which cropped subdirectory to send based on the first config
            first_mod, _ = self._module_for_config(first_config)
            crop_subdir = "generator" if first_mod.startswith("generator/") else "classifier"
            local_zip = self.project_root / f"cropped_{crop_subdir}.zip"
            if not local_zip.exists():
                print(f" Warning: --send-cropped set but cropped_{crop_subdir}.zip not found.")
                print(f" Create it with: zip -r cropped_{crop_subdir}.zip cropped/{crop_subdir}/")
            else:
                size_mb = local_zip.stat().st_size / (1024 * 1024)
                print(f" Sending cropped_{crop_subdir}.zip ({size_mb:.0f} MB)...")
                rsync_file(local_zip, target, ws)
                print(" Unzipping on remote...")
                ssh(target, shell_join(["bash", "-lc", f"cd {ws} && unzip -q -o cropped_{crop_subdir}.zip"]))
                print(f" Pre-cropped {crop_subdir} images ready.")
        data_ready = self._remote_data_exists(target, f"{ws}/data")
        if download_data or not data_ready:
            # Download the dataset in a detached process so the SSH timeout doesn't kill it
            print(" Downloading dataset from HuggingFace...")
            log = f"{ws}/.pipeline_logs/fetch_ds.log"
            ssh(target, f"mkdir -p {ws}/.pipeline_logs")
            pid = run_detached(
                target,
                f"cd {ws} && source .venv/bin/activate && python3 classifier/tools/fetch_ds.py",
                log,
            )
            reconnects = 0
            # Poll until the fetch process exits, with reconnect logic for SSH drops
            while True:
                try:
                    if not is_process_alive(target, pid):
                        break
                    reconnects = 0
                except RemoteCommandError as exc:
                    # Exit status 255 is ssh's own transport failure; retry with back-off
                    if "Command failed (255)" not in str(exc) or reconnects >= 10:
                        raise
                    reconnects += 1
                    time.sleep(min(5 * reconnects, 30))
                    continue
                time.sleep(poll)
            print(" Dataset ready.")

    # ── config routing ───────────────────────────────────────────────

    def _module_for_config(self, config_path: Path) -> tuple[str, str]:
        """Determine which training script and output dir to use based on config location."""
        try:
            rel = config_path.resolve().relative_to(self.project_root)
        except ValueError:
            # Config lives outside the project tree; route by its own path parts
            rel = config_path
        if rel.parts[0] == "generator":
            return "generator/run.py", "generator/outputs"
        return "classifier/run.py", "classifier/outputs"

    # ── remote directory checks ──────────────────────────────────────

    def _remote_dir_exists(self, target: SshTarget, path: str) -> bool:
        """Check whether a directory exists on the remote host."""
        out = ssh_output(target, f"if [ -d {path} ]; then echo yes; else echo no; fi").strip()
        return out == "yes"

    def _remote_data_exists(self, target: SshTarget, data_dir: str) -> bool:
        """Check that the dataset is populated (not just an empty directory)."""
        out = ssh_output(target, f"if [ -d {data_dir}/wiki ]; then echo yes; else echo no; fi").strip()
        return out == "yes"

    # ── training ─────────────────────────────────────────────────────

    def _run_training(self, target: SshTarget, config_path: Path, opts: RunOptions) -> None:
        """Launch the training script under nohup, stream logs, and handle reconnects.

        Raises RemoteCommandError (with a log tail) when training exits non-zero.
        """
        ws = self.config.remote["workspace_dir"]
        config_rel = config_path.resolve().relative_to(self.project_root)
        run_script, output_root = self._module_for_config(config_path)
        is_generator = run_script.startswith("generator/")
        run_cmd = shell_join(["python3", "-u", run_script, str(config_rel), "--output-root", output_root])
        if opts.use_gpu:
            run_cmd += " --use-gpu"
        log_dir = f"{ws}/.pipeline_logs"
        log_path = f"{log_dir}/{config_path.stem}.log"
        exit_path = f"{log_dir}/{config_path.stem}.exit"
        ssh(target, f"mkdir -p {log_dir}")
        # Run training under nohup; write the exit code to a marker file when done
        pid = run_detached(
            target,
            f"cd {ws} && source .venv/bin/activate && {run_cmd}; echo $? > {exit_path}",
            log_path,
        )
        print(f" Training started (PID {pid})")
        poll = self.config.remote["poll_interval_seconds"]
        max_reconnects = 10
        reconnects = 0
        # Generator runs sync intermediate outputs back every N epochs
        next_sync = _GENERATOR_SYNC_INTERVAL if is_generator else None
        local_output_root = self.project_root / output_root
        stop_streaming = threading.Event()

        # Background thread: tail the remote log and print tqdm-aware output
        def _stream_worker() -> None:
            tail_proc = None
            last_was_progress = False
            try:
                time.sleep(1)  # wait for the log file to be created
                tail_proc = tail_log(target, log_path)
                for raw in tail_proc.stdout:
                    if stop_streaming.is_set():
                        break
                    # tqdm uses \r for in-place updates; take only the last segment
                    line = raw.rstrip("\r\n").rsplit("\r", 1)[-1].rstrip()
                    if not line:
                        continue
                    is_progress = "it/s]" in line or "%|" in line
                    if is_progress:
                        print(f"\r {line}", end="", flush=True)
                        last_was_progress = True
                    else:
                        if last_was_progress:
                            print(flush=True)
                            last_was_progress = False
                        print(f" {line}", flush=True)
            except Exception:
                # Best-effort streaming: a dropped tail must never kill the run
                pass
            finally:
                if last_was_progress:
                    print(flush=True)
                if tail_proc is not None:
                    tail_proc.terminate()
                    tail_proc.wait(timeout=5)

        stream_thread = threading.Thread(target=_stream_worker, daemon=True)
        stream_thread.start()
        try:
            # Main polling loop: check if the process is alive, periodically sync generator outputs
            while True:
                try:
                    alive = is_process_alive(target, pid)
                except RemoteCommandError as exc:
                    if "Command failed (255)" not in str(exc) or reconnects >= max_reconnects:
                        raise
                    reconnects += 1
                    print(f"\n SSH dropped (reconnect {reconnects}/{max_reconnects}), "
                          "training continues on remote...")
                    time.sleep(min(5 * reconnects, 30))
                    continue
                if is_generator and next_sync is not None:
                    epoch = self._latest_epoch_in_log(target, log_path)
                    if epoch >= next_sync:
                        print(f" Syncing generator outputs at epoch {next_sync}...")
                        self._fetch_outputs(target, output_root, local_output_root)
                        next_sync += _GENERATOR_SYNC_INTERVAL
                if not alive:
                    break
                reconnects = 0
                time.sleep(poll)
        finally:
            stop_streaming.set()
            stream_thread.join(timeout=10)
        # Read the exit code marker that the nohup wrapper wrote.
        # Guard against a missing or empty marker (e.g. instance killed mid-write).
        raw_exit = read_remote_file(target, exit_path)
        try:
            exit_code = int(raw_exit) if raw_exit is not None else -1
        except ValueError:
            exit_code = -1
        if exit_code != 0:
            # Salvage whatever outputs exist before surfacing the failure
            if self._remote_dir_exists(target, f"{ws}/{output_root}"):
                self._fetch_outputs(target, output_root, local_output_root)
            log_tail = read_remote_file(target, log_path)
            snippet = "\n".join((log_tail or "").splitlines()[-30:])
            raise RemoteCommandError(f"Training failed (exit {exit_code}).\n{snippet}")

    @staticmethod
    def _latest_epoch_in_log(target: SshTarget, log_path: str) -> int:
        """Parse the latest epoch number from the training log (used for generator sync scheduling).

        Assumes the training script logs epochs as "[epoch/total]" — confirm
        against generator/run.py if the pattern changes.
        """
        try:
            out = ssh_output(target, f"tail -n 200 {log_path} 2>/dev/null || true")
            epochs = [int(m) for m in re.findall(r"\[(\d+)/\d+\]", out)]
            return max(epochs, default=0)
        except RemoteCommandError:
            return 0

    def _fetch_outputs(self, target: SshTarget, remote_output_root: str, local_output_root: Path) -> None:
        """Download remote outputs with up to 3 retry attempts (exponential back-off)."""
        remote_path = f"{self.config.remote['workspace_dir']}/{remote_output_root}"
        delay = 5
        for attempt in range(1, 4):
            try:
                fetch_directory(target, remote_path, local_output_root)
                return
            except RemoteCommandError as exc:
                if attempt < 3:
                    print(f" Download attempt {attempt} failed: {exc}")
                    print(f" Retrying in {delay}s...")
                    time.sleep(delay)
                    delay *= 2
                else:
                    raise

    # ── public commands ──────────────────────────────────────────────

    def run(self, config_paths: list[Path], opts: RunOptions) -> None:
        """Full pipeline: validate configs -> find offer -> rent instance -> train -> fetch -> destroy.

        Raises FileNotFoundError for missing configs and re-raises any failure
        after writing a "failed" manifest; the instance is destroyed in all
        cases unless keep-on-failure is requested.
        """
        resolved = []
        for cp in config_paths:
            cp = (self.project_root / cp).resolve() if not cp.is_absolute() else cp
            if not cp.exists():
                raise FileNotFoundError(f"Config not found: {cp}")
            resolved.append(cp)
        n = len(resolved)
        _, first_output_root = self._module_for_config(resolved[0])
        local_output_root = self.project_root / first_output_root
        self.api.ensure_ssh_key(self.public_key)
        offers = self._search_offers(opts.sort_mode, opts.region, opts.price_cap)
        if opts.list_regions:
            self._print_regions(offers)
            return
        try:
            offer = self._choose_offer(offers) if opts.select_offer else offers[0]
        except (OfferSelectionAborted, KeyboardInterrupt):
            print("Aborted.")
            return
        if opts.dry_run:
            print(json.dumps(offer if opts.select_offer else offers[:10], indent=2))
            return
        # Label after the single config, or the parent dir for a batch
        label_stem = resolved[0].stem if n == 1 else resolved[0].parent.name
        template_hash_id = self._resolve_template(opts)
        offer_id = int(offer["id"])
        cost_per_hour = float(offer.get("dph_total", 0))
        manifest = RunManifest(
            created_at=datetime.now(UTC).isoformat(),
            config_paths=[str(cp.relative_to(self.project_root)) for cp in resolved],
            instance_id=None,
            offer_id=offer_id,
            ssh_host=None,
            ssh_port=None,
            status="creating",
            remote_workspace=self.config.remote["workspace_dir"],
        )
        manifest_path = self._write_manifest(manifest, local_output_root)
        instance_id: int | None = None
        should_destroy = True
        start_time = time.time()
        try:
            # Build the instance payload and rent the GPU
            payload = self._build_instance_payload(self._label_for(label_stem), template_hash_id=template_hash_id)
            print(f"Creating instance from offer {offer_id} (${cost_per_hour:.3f}/h)...")
            instance_id = self.api.create_instance(offer_id, payload)
            self.api.attach_ssh_key(instance_id, self.public_key)
            # Persist the instance ID immediately so a timeout below still
            # leaves enough on disk to destroy the instance by hand
            manifest.instance_id = instance_id
            self._write_manifest(manifest, local_output_root)
            # Wait for the instance to be SSH-ready, then update the manifest
            print(f"Waiting for instance {instance_id}...")
            instance = self._wait_for_instance(instance_id)
            manifest.ssh_host = instance.ssh_host
            manifest.ssh_port = instance.ssh_port
            manifest.status = instance.actual_status
            self._write_manifest(manifest, local_output_root)
            print(f"Instance ready: {instance.ssh_host}:{instance.ssh_port} "
                  f"({instance.gpu_name}, ${instance.dph_total}/h)")
            target = self._build_target(instance)
            # Wait for SSH to become available before running any remote commands
            print("Waiting for SSH...")
            wait_for_ssh(
                target,
                timeout_seconds=self.config.remote["ssh_timeout_seconds"],
                poll_interval_seconds=self.config.remote["poll_interval_seconds"],
            )
            self._prepare_remote_workspace(
                target,
                download_data=opts.download_data,
                send_cropped=opts.send_cropped,
                first_config=resolved[0],
            )
            # Run each training config sequentially on the same instance
            for i, config_path in enumerate(resolved, 1):
                print(f"\n[{i}/{n}] Training: {config_path.name}")
                self._run_training(target, config_path, opts)
                print(f"[{i}/{n}] Fetching outputs...")
                _, output_root = self._module_for_config(config_path)
                self._fetch_outputs(target, output_root, self.project_root / output_root)
            elapsed = time.time() - start_time
            manifest.status = "completed"
            self._write_manifest(manifest, local_output_root)
            print(f"\nAll {n} run(s) completed in {self._fmt_duration(elapsed)} "
                  f"(~{self._fmt_cost(elapsed, cost_per_hour)}). Manifest: {manifest_path}")
        except KeyboardInterrupt:
            elapsed = time.time() - start_time
            manifest.status = "cancelled"
            self._write_manifest(manifest, local_output_root)
            print(f"\nCancelled after {self._fmt_duration(elapsed)} "
                  f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
        except Exception:
            elapsed = time.time() - start_time
            manifest.status = "failed"
            self._write_manifest(manifest, local_output_root)
            print(f" Failed after {self._fmt_duration(elapsed)} "
                  f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
            if opts.keep_on_failure or self.config.keep_on_failure:
                should_destroy = False
            raise
        finally:
            if instance_id is not None and should_destroy:
                elapsed = time.time() - start_time
                print(f"Destroying instance {instance_id} "
                      f"(total: {self._fmt_duration(elapsed)}, "
                      f"~{self._fmt_cost(elapsed, cost_per_hour)})")
                self.api.destroy_instance(instance_id)

    def offers(
        self,
        *,
        sort_mode: str | None,
        region: str | None,
        price_cap: float | None,
        select_offer: bool,
        list_regions: bool,
        limit_output: int,
    ) -> None:
        """CLI subcommand: search and display available GPU offers without renting."""
        offers = self._search_offers(sort_mode, region, price_cap)
        if list_regions:
            self._print_regions(offers)
            return
        if select_offer:
            try:
                offer = self._choose_offer(offers)
            except (OfferSelectionAborted, KeyboardInterrupt):
                print("Aborted.")
                return
            print(json.dumps(offer, indent=2))
        else:
            print(json.dumps(offers[:limit_output], indent=2))

    def up(self, *, label: str | None, opts: RunOptions | None = None) -> None:
        """CLI subcommand: rent an instance and print its connection details (no training).

        Uses the top-ranked offer from the default search (opts only controls
        template selection here).
        """
        if opts is None:
            opts = RunOptions()
        self.api.ensure_ssh_key(self.public_key)
        template_hash_id = self._resolve_template(opts)
        offer = self._search_offers()[0]
        offer_id = int(offer["id"])
        payload = self._build_instance_payload(
            label or self._label_for("manual"), template_hash_id=template_hash_id
        )
        instance_id = self.api.create_instance(offer_id, payload)
        self.api.attach_ssh_key(instance_id, self.public_key)
        instance = self._wait_for_instance(instance_id)
        print(json.dumps({
            "offer_id": offer_id,
            "instance_id": instance_id,
            "ssh_host": instance.ssh_host,
            "ssh_port": instance.ssh_port,
            "gpu_name": instance.gpu_name,
            "dph_total": instance.dph_total,
        }, indent=2))

    def status(self, instance_id: int) -> None:
        """CLI subcommand: print the raw instance status JSON."""
        print(json.dumps(self.api.show_instance(instance_id).raw, indent=2))

    def down(self, instance_id: int) -> None:
        """CLI subcommand: destroy a running instance by ID."""
        self.api.destroy_instance(instance_id)
        print(f"Destroyed instance {instance_id}")