Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+164
View File
@@ -0,0 +1,164 @@
# Pipeline
Orchestrates ephemeral Vast.ai GPU instances: searches for an offer, creates the instance, syncs the project, trains, downloads `outputs/`, and destroys the instance automatically. Generator runs also rsync `generator/outputs/` every 50 epochs while training is still running.
## One-time setup
Create `pipeline/.env`:
```dotenv
VAST_API_KEY=your-vast-api-key
VAST_SSH_PRIVATE_KEY=/home/you/.ssh/id_ed25519 # optional, this is the default
```
The matching `.pub` file must exist alongside the private key. The pipeline registers it with Vast.ai automatically if it isn't there yet.
## Commands
### `run` — train on a remote GPU and fetch results
```
python -m pipeline run <config...> [options]
```
Accepts one or more config paths, or a single directory (all `*.json` inside, sorted). Duplicate configs (identical training settings after resolving `extends` and `shared.json`) are skipped automatically.
| Flag | Default | Description |
|------|---------|-------------|
| `configs` | *(required)* | One or more config paths, or a directory of JSON configs |
| `--download-data` | off | Download the DFF dataset via HuggingFace on the remote before training |
| `--send-cropped` | off | Rsync local `cropped/{classifier,generator}/` to remote (picks subdirectory based on config) |
| `--select-offer` | off | Interactively browse and pick the GPU offer |
| `--sort` | config | Ranking mode: `price`, `performance`, or `dlp_per_dollar` |
| `--region TEXT` | any | Filter by region, e.g. `europe`, `Portugal`, `US` |
| `--price FLOAT` | config | Max hourly price cap in USD |
| `--dry-run` | off | Print matching offers without creating an instance |
| `--keep-on-failure` | off | Do not destroy the instance if training fails |
| `--no-gpu` | off | Disable GPU training on remote (use CPU instead) |
| `--select-template` | off | Interactively choose a Vast.ai Docker template |
| `--template HASH` | config | Use a specific template hash ID |
| `--pipeline-config PATH` | none | JSON file that overrides `pipeline/defaults/vast.json` |
**Examples:**
```bash
# Cheapest available RTX 3090 in Europe, download data on remote
python -m pipeline run configs/resnet18.json --region europe --download-data
# Browse offers interactively, sort by price
python -m pipeline run configs/resnet18.json --select-offer --sort price
# Run all configs in a directory sequentially on one instance
python -m pipeline run configs/phase2/ --region europe
# See what offers would be selected without spending money
python -m pipeline run configs/resnet18.json --dry-run --region europe
# Keep the instance alive if something goes wrong (for debugging)
python -m pipeline run configs/resnet18.json --keep-on-failure
# Cap price at $0.12/h
python -m pipeline run configs/resnet18.json --price 0.12
```
### `offers` — inspect available GPU offers
```
python -m pipeline offers [options]
```
| Flag | Default | Description |
|------|---------|-------------|
| `--sort` | config | Ranking mode: `price`, `performance`, or `dlp_per_dollar` |
| `--region TEXT` | any | Region filter |
| `--price FLOAT` | config | Max hourly price cap |
| `--select-offer` | off | Interactive offer picker (prints the selected offer as JSON) |
| `--list-regions` | off | Print a count of available offers per region and exit |
| `--limit-output INT` | 10 | How many offers to print |
| `--pipeline-config PATH` | none | Pipeline config override |
**Examples:**
```bash
# See the 20 best-value offers under $0.15/h in Europe
python -m pipeline offers --region europe --price 0.15 --limit-output 20
# List which regions have matching GPUs
python -m pipeline offers --list-regions
# Interactive picker — useful before committing to a run
python -m pipeline offers --select-offer --sort price
```
### `up` — create an instance without training
Spins up an instance and prints SSH connection details. Useful for manual experiments or debugging.
```
python -m pipeline up [options]
```
| Flag | Default | Description |
|------|---------|-------------|
| `--label TEXT` | auto | Optional label for the instance |
| `--select-template` | off | Interactively choose a Vast.ai Docker template |
| `--template HASH` | config | Use a specific template hash ID |
| `--pipeline-config PATH` | none | Pipeline config override |
```bash
python -m pipeline up
python -m pipeline up --label my-debug-session
```
### `status` — show instance details
```
python -m pipeline status <instance_id> [--pipeline-config PATH]
```
### `down` — destroy an instance
```
python -m pipeline down <instance_id> [--pipeline-config PATH]
```
## Pipeline config overrides
Pass `--pipeline-config my_overrides.json` to override any field from `pipeline/defaults/vast.json`. Only the fields you specify are changed; the rest keep their defaults (deep-merged). Useful for switching GPU types or raising the price cap for a single run without editing defaults.
**Example — allow RTX 4090, higher price cap:**
```json
{
"search": {
"gpu_names": ["RTX 4090"],
"max_dph_total": 0.45
}
}
```
**Key fields in `pipeline/defaults/vast.json`:**
| Section | Key | Default | Meaning |
|---------|-----|---------|---------|
| `search` | `gpu_names` | `["RTX 3090", "RTX 3090 Ti"]` | Accepted GPU models |
| `search` | `max_dph_total` | `0.40` | Max price per hour |
| `search` | `sort_mode` | `"dlp_per_dollar"` | Default ranking (`price`, `performance`, or `dlp_per_dollar`) |
| `search` | `min_reliability` | `0.98` | Minimum host reliability score |
| `instance` | `disk_gb` | `48` | Disk size provisioned on the instance |
| `instance` | `image` | `"vastai/pytorch:latest"` | Docker image |
| `remote` | `workspace_dir` | `"/workspace/DRL_PROJ"` | Remote working directory |
| `remote` | `ssh_timeout_seconds` | `900` | How long to wait for SSH to become available |
## Full workflow example
```bash
# 1. Check what's available and how much it costs
python -m pipeline offers --region europe --list-regions
python -m pipeline offers --region europe --sort price --limit-output 20
# 2. Run training (auto-selects best offer, downloads data if needed)
python -m pipeline run configs/resnet18.json --region europe --download-data
# 3. Results land in classifier/outputs/ automatically
```
+2
View File
@@ -0,0 +1,2 @@
"""Ephemeral Vast.ai training pipeline."""
+6
View File
@@ -0,0 +1,6 @@
from pipeline.cli import main
if __name__ == "__main__":
raise SystemExit(main())
+130
View File
@@ -0,0 +1,130 @@
import argparse
from pathlib import Path
from pipeline.orchestrator import EphemeralVastRunner, RunOptions
# Accept one or more config files, or a single directory (all *.json inside, sorted)
def _resolve_configs(raw: list[str]) -> list[Path]:
if len(raw) == 1 and Path(raw[0]).is_dir():
configs = sorted(Path(raw[0]).glob("*.json"))
if not configs:
raise ValueError(f"No JSON configs found in directory: {raw[0]}")
return configs
return [Path(p) for p in raw]
# Build the argparse CLI with subcommands: offers, run, up, status, down
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Ephemeral Vast.ai training pipeline.")
subparsers = parser.add_subparsers(dest="command", required=True)
# Shared offer-related flags used by both "offers" and "run" subcommands
def add_offer_options(p: argparse.ArgumentParser) -> None:
p.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")
p.add_argument("--sort", choices=["price", "performance", "dlp_per_dollar"], default=None, help="Override offer ranking mode.")
p.add_argument("--region", default=None, help="Filter offers by region (e.g. 'europe', 'Portugal').")
p.add_argument("--price", type=float, default=None, help="Override max hourly price cap in USD.")
p.add_argument("--select-offer", action="store_true", help="Interactively choose an offer.")
# Shared template-related flags used by "run" and "up" subcommands
def add_template_options(p: argparse.ArgumentParser) -> None:
p.add_argument("--select-template", action="store_true", help="Interactively choose a Vast.ai template.")
p.add_argument("--template", default=None, help="Template hash ID to use for instance creation.")
# ── offers ──────────────────────────────────────────────────────────
offers_parser = subparsers.add_parser("offers", help="Inspect available Vast offers.")
add_offer_options(offers_parser)
offers_parser.add_argument("--list-regions", action="store_true", help="List matching regions and exit.")
offers_parser.add_argument("--limit-output", type=int, default=10, help="How many offers to print (default: 10).")
# ── run ─────────────────────────────────────────────────────────────
run_parser = subparsers.add_parser("run", help="Create instance, train, fetch outputs, destroy it.")
run_parser.add_argument("configs", nargs="+", help="One or more config paths, or a directory of JSON configs.")
add_offer_options(run_parser)
add_template_options(run_parser)
run_parser.add_argument("--download-data", action="store_true", help="Force download dataset via HuggingFace on the remote before training (auto-downloads if missing).")
run_parser.add_argument("--send-cropped", action="store_true", help="Rsync local cropped/ subdirectory to remote before training (sends only classifier or generator based on config).")
run_parser.add_argument("--keep-on-failure", action="store_true", help="Keep instance on failure.")
run_parser.add_argument("--dry-run", action="store_true", help="Search and print offers without creating instance.")
run_parser.add_argument("--no-gpu", action="store_true", help="Disable GPU training on remote (use CPU instead).")
# ── up ──────────────────────────────────────────────────────────────
up_parser = subparsers.add_parser("up", help="Create instance and print SSH details.")
up_parser.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")
add_template_options(up_parser)
up_parser.add_argument("--label", default=None, help="Optional label for the instance.")
# ── status ──────────────────────────────────────────────────────────
status_parser = subparsers.add_parser("status", help="Show instance details.")
status_parser.add_argument("instance_id", type=int)
status_parser.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")
# ── down ────────────────────────────────────────────────────────────
down_parser = subparsers.add_parser("down", help="Destroy an instance.")
down_parser.add_argument("instance_id", type=int)
down_parser.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")
return parser
# Parse CLI args and dispatch to the appropriate runner method
def main(argv=None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
override = Path(args.pipeline_config) if getattr(args, "pipeline_config", None) else None
runner = EphemeralVastRunner(override)
# ── dispatch ────────────────────────────────────────────────────────
if args.command == "offers":
runner.offers(
sort_mode=args.sort,
region=args.region,
price_cap=args.price,
select_offer=args.select_offer,
list_regions=args.list_regions,
limit_output=args.limit_output,
)
return 0
if args.command == "run":
opts = RunOptions(
download_data=args.download_data,
send_cropped=args.send_cropped,
keep_on_failure=args.keep_on_failure,
dry_run=args.dry_run,
select_offer=args.select_offer,
select_template=args.select_template,
template_hash=args.template,
sort_mode=args.sort,
region=args.region,
price_cap=args.price,
use_gpu=not args.no_gpu, # Default to True, disable with --no-gpu
)
runner.run(_resolve_configs(args.configs), opts)
return 0
if args.command == "up":
up_opts = RunOptions(
select_template=getattr(args, "select_template", False),
template_hash=getattr(args, "template", None),
)
runner.up(label=args.label, opts=up_opts)
return 0
if args.command == "status":
runner.status(args.instance_id)
return 0
if args.command == "down":
runner.down(args.instance_id)
return 0
parser.error(f"Unknown command: {args.command}")
return 2
+182
View File
@@ -0,0 +1,182 @@
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# Load environment variables from pipeline/.env if present
# This file is intentionally optional and gitignored
# Expected keys:
# VAST_API_KEY=<your-api-key>
# VAST_SSH_PRIVATE_KEY=/home/your-user/.ssh/id_ed25519
def load_dotenv(dotenv_path: Path) -> None:
if not dotenv_path.exists():
return
for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines():
line = raw_line.strip()
if not line or line.startswith("#"):
continue
if "=" not in line:
continue
key, value = line.split("=", 1)
key = key.strip()
value = value.strip().strip("'").strip('"')
os.environ.setdefault(key, value)
# Default offer search settings
DEFAULT_SEARCH: dict[str, Any] = {
"limit": 100,
"order_by": "dlperf",
"order_direction": "desc",
"sort_mode": "performance",
"offer_type": "ondemand",
"verified_only": True,
"rentable": True,
"rented": False,
"num_gpus": 1,
"min_reliability": 0.98,
"min_cuda_ram_mb": 0,
"min_cpu_ram_mb": 0,
"min_disk_space_gb": 0,
"min_direct_port_count": 1,
"max_dph_total": 0.20,
"gpu_names": ["RTX 3090", "RTX 3090 Ti"],
"countries_exclude": [],
}
# Default instance settings
DEFAULT_INSTANCE: dict[str, Any] = {
"image": "vastai/pytorch:latest",
"disk_gb": 48,
"target_state": "running",
"label_prefix": "drl-proj",
"template_hash_id": None,
}
# Default remote SSH/workspace settings
DEFAULT_REMOTE: dict[str, Any] = {
"ssh_user": "root",
"workspace_dir": "/workspace/DRL_PROJ",
"remote_data_dir": "/workspace/data/DeepFakeFace",
"remote_output_root": "classifier/outputs",
"ssh_timeout_seconds": 900,
"poll_interval_seconds": 10,
}
# Default local transfer paths
DEFAULT_TRANSFER: dict[str, Any] = {
"local_output_dir": "classifier/outputs",
"local_data_dir": "data",
}
# Merge nested dictionaries recursively
# Used when a user override file should replace only specific nested keys
# (for example, overriding just search.max_dph_total without redefining search)
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
result = dict(base)
for key, value in override.items():
if isinstance(value, dict) and isinstance(result.get(key), dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
# Runtime pipeline config container and loader
@dataclass(slots=True)
class PipelineConfig:
search: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_SEARCH))
instance: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_INSTANCE))
remote: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_REMOTE))
transfer: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_TRANSFER))
keep_on_failure: bool = False
# Load defaults, apply overrides, and return a PipelineConfig instance
@classmethod
def load(cls, project_root: Path, override_path: Path | None = None) -> "PipelineConfig":
load_dotenv(project_root / "pipeline" / ".env")
# Load default settings from vast.json
defaults_path = project_root / "pipeline" / "defaults" / "vast.json"
with open(defaults_path, encoding="utf-8") as handle:
raw = json.load(handle)
# Apply user-provided override config (deep merge keeps unspecified defaults)
if override_path is not None:
with open(override_path, encoding="utf-8") as handle:
raw = _deep_merge(raw, json.load(handle))
return cls(
search=raw.get("search", {}),
instance=raw.get("instance", {}),
remote=raw.get("remote", {}),
transfer=raw.get("transfer", {}),
keep_on_failure=raw.get("keep_on_failure", False),
)
# Convert the internal "search" config shape into Vast.ai query format
def build_offer_query(self, *, price_cap: float | None = None) -> dict[str, Any]:
s = self.search
query: dict[str, Any] = {
"limit": s["limit"],
"order": [[s["order_by"], s["order_direction"]]],
"type": s["offer_type"],
"verified": {"eq": s["verified_only"]},
"rentable": {"eq": s["rentable"]},
"rented": {"eq": s["rented"]},
"num_gpus": {"gte": s["num_gpus"]},
"reliability2": {"gte": s["min_reliability"]},
"gpu_ram": {"gte": s["min_cuda_ram_mb"]},
"cpu_ram": {"gte": s["min_cpu_ram_mb"]},
"disk_space": {"gte": s["min_disk_space_gb"]},
"direct_port_count": {"gte": s["min_direct_port_count"]},
}
# CLI-provided price cap takes precedence over config default
max_price = price_cap or s.get("max_dph_total")
if max_price is not None:
query["dph_total"] = {"lte": max_price}
gpu_names = s.get("gpu_names", [])
if gpu_names:
query["gpu_name"] = {"in": gpu_names}
excluded = s.get("countries_exclude", [])
if excluded:
query["geolocation"] = {"notin": excluded}
return query
# Resolve the local directory where training outputs are saved
def local_output_path(self, project_root: Path) -> Path:
return (project_root / self.transfer["local_output_dir"]).resolve()
# Resolve the local directory where datasets are stored
def local_data_path(self, project_root: Path) -> Path:
return (project_root / self.transfer["local_data_dir"]).resolve()
# Read the Vast.ai API key from the environment, failing fast if unset
def resolve_api_key() -> str:
api_key = os.environ.get("VAST_API_KEY")
if not api_key:
raise RuntimeError("VAST_API_KEY is not set.")
return api_key
# Resolve the local SSH private key path, defaulting to ~/.ssh/id_ed25519
def resolve_ssh_private_key() -> Path:
raw = os.environ.get("VAST_SSH_PRIVATE_KEY", "~/.ssh/id_ed25519")
path = Path(raw).expanduser()
if not path.exists():
raise RuntimeError(
f"SSH private key not found at {path}. Set VAST_SSH_PRIVATE_KEY to the correct path."
)
return path
# Read the corresponding .pub file for a given private key
def read_public_key(private_key_path: Path) -> str:
public_key_path = Path(f"{private_key_path}.pub")
if not public_key_path.exists():
raise RuntimeError(
f"SSH public key not found at {public_key_path}. Generate it or set VAST_SSH_PRIVATE_KEY."
)
return public_key_path.read_text(encoding="utf-8").strip()
+41
View File
@@ -0,0 +1,41 @@
{
"search": {
"limit": 100,
"order_by": "dlperf",
"order_direction": "desc",
"sort_mode": "dlp_per_dollar",
"offer_type": "ondemand",
"verified_only": true,
"rentable": true,
"rented": false,
"num_gpus": 1,
"min_reliability": 0.98,
"min_cuda_ram_mb": 0,
"min_cpu_ram_mb": 0,
"min_disk_space_gb": 0,
"min_direct_port_count": 1,
"max_dph_total": 0.40,
"gpu_names": ["RTX 3090", "RTX 3090 Ti"],
"countries_exclude": []
},
"instance": {
"image": "vastai/pytorch:latest",
"disk_gb": 48,
"target_state": "running",
"label_prefix": "drl-proj",
"template_hash_id": null
},
"remote": {
"ssh_user": "root",
"workspace_dir": "/workspace/DRL_PROJ",
"remote_data_dir": "/workspace/data/DeepFakeFace",
"remote_output_root": "classifier/outputs",
"ssh_timeout_seconds": 900,
"poll_interval_seconds": 10
},
"transfer": {
"local_output_dir": "classifier/outputs",
"local_data_dir": "data"
},
"keep_on_failure": false
}
+817
View File
@@ -0,0 +1,817 @@
import json
import re
import threading
import time
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from pipeline.config import (
PipelineConfig,
read_public_key,
resolve_api_key,
resolve_ssh_private_key,
)
from pipeline.remote import (
RemoteCommandError,
SshTarget,
fetch_directory,
is_process_alive,
read_remote_file,
rsync_file,
rsync_project,
run_detached,
shell_join,
ssh,
ssh_output,
tail_log,
wait_for_ssh,
)
from pipeline.vast_api import VastApiClient, VastApiError, VastInstance
# ── data classes ─────────────────────────────────────────────────────
# Snapshot of a pipeline run written to disk for traceability
@dataclass(slots=True)
class RunManifest:
created_at: str
config_paths: list[str]
instance_id: int | None
offer_id: int | None
ssh_host: str | None
ssh_port: int | None
status: str
remote_workspace: str | None
# CLI flags controlling how the pipeline run behaves
@dataclass
class RunOptions:
download_data: bool = False
send_cropped: bool = False
keep_on_failure: bool = False
dry_run: bool = False
select_offer: bool = False
select_template: bool = False
template_hash: str | None = None
sort_mode: str | None = None
region: str | None = None
price_cap: float | None = None
list_regions: bool = False
use_gpu: bool = True
# ── constants ─────────────────────────────────────────────────────────
# ISO-3166-1-alpha-2 country codes used by the --region europe filter
EUROPE_REGION_CODES = {
"AL", "AD", "AM", "AT", "AZ", "BA", "BE", "BG", "BY", "CH", "CY", "CZ",
"DE", "DK", "EE", "ES", "FI", "FR", "GB", "GE", "GR", "HR", "HU", "IE",
"IS", "IT", "LI", "LT", "LU", "LV", "MC", "MD", "ME", "MK", "MT", "NL",
"NO", "PL", "PT", "RO", "RS", "SE", "SI", "SK", "SM", "TR", "UA", "VA",
}
# How often (in epochs) to rsync generator outputs back during training
_GENERATOR_SYNC_INTERVAL = 50
# Raised when the user aborts interactive offer selection
class OfferSelectionAborted(RuntimeError):
pass
# Raised when the user aborts interactive template selection
class TemplateSelectionAborted(RuntimeError):
pass
# ── runner ────────────────────────────────────────────────────────────
# Orchestrates the full lifecycle: search offers -> rent GPU -> train -> fetch results -> destroy
class EphemeralVastRunner:
# Pre-defined Vast.ai image templates the user can pick interactively
TEMPLATE_CATALOG: list[dict[str, str]] = [
{
"hash_id": "661d064bbda1f2a133816b6d55da07c3",
"name": "PyTorch (cuDNN Devel)",
"image": "pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel",
},
{
"hash_id": "b9e5a8f3d4c1e7f6a2b0d8c9e5f1a3b7",
"name": "PyTorch (Latest)",
"image": "pytorch/pytorch:latest",
},
{
"hash_id": "none",
"name": "Custom image (no template)",
"image": "",
},
]
# Resolve the project root (two levels up from this file)
def __init__(self, override_path: Path | None) -> None:
self.project_root = Path(__file__).resolve().parent.parent
self.config = PipelineConfig.load(self.project_root, override_path)
self.api = VastApiClient(resolve_api_key())
self.private_key = resolve_ssh_private_key()
self.public_key = read_public_key(self.private_key)
self.exclude_file = self.project_root / "pipeline" / "rsync-excludes.txt"
# ── formatting helpers ────────────────────────────────────────────
# Format seconds into a compact human-readable string (e.g. "1d2h30m")
@staticmethod
def _fmt_duration(seconds: float) -> str:
s = max(int(seconds), 0)
days, s = divmod(s, 86400)
hours, s = divmod(s, 3600)
minutes, _ = divmod(s, 60)
parts: list[str] = []
if days:
parts.append(f"{days}d")
if hours or days:
parts.append(f"{hours}h")
parts.append(f"{minutes}m")
return "".join(parts)
# Estimate cost from wall-clock time and the offer's $/h rate
@staticmethod
def _fmt_cost(elapsed_seconds: float, cost_per_hour: float) -> str:
return f"${(elapsed_seconds / 3600) * cost_per_hour:.2f}"
# ── offer search & selection ──────────────────────────────────────
# Build a sort tuple that ranks offers by the chosen mode (performance / price / dlp_per_dollar)
@staticmethod
def _offer_sort_key(offer: dict, sort_mode: str) -> tuple:
price = float(offer.get("dph_total", float("inf")))
dlperf = float(offer.get("dlperf", float("-inf")))
reliability = float(offer.get("reliability2", 0.0))
if sort_mode == "price":
return (price, -dlperf, -reliability)
if sort_mode == "dlp_per_dollar":
per_dollar = dlperf / price if price > 0 else 0.0
return (-per_dollar, -reliability, price)
return (-dlperf, price, -reliability)
# Match a user-friendly region string (e.g. "europe") against the offer's geolocation code
@staticmethod
def _region_matches(offered: str, requested: str) -> bool:
req = requested.strip().lower()
if not req:
return True
if req == "europe":
code = offered.rsplit(", ", 1)[-1].strip() if ", " in offered else ""
return code in EUROPE_REGION_CODES
return req in offered.lower()
# Query Vast.ai for matching offers, optionally filter by region, then sort
def _search_offers(
self,
sort_mode: str | None = None,
region: str | None = None,
price_cap: float | None = None,
) -> list[dict]:
query = self.config.build_offer_query(price_cap=price_cap)
offers = self.api.search_offers(query)
if not offers:
raise VastApiError("No Vast offers matched the configured filters.")
if region:
offers = [o for o in offers if self._region_matches(o.get("geolocation") or "", region)]
if not offers:
raise VastApiError(f"No offers matched region filter: {region!r}")
mode = sort_mode or self.config.search.get("sort_mode", "performance")
return sorted(offers, key=lambda o: self._offer_sort_key(o, mode))
# Print a summary of available regions and their offer counts
@staticmethod
def _print_regions(offers: list[dict]) -> None:
counts: dict[str, int] = {}
for o in offers:
region = o.get("geolocation") or "Unknown"
counts[region] = counts.get(region, 0) + 1
for region, count in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
print(f"{region}: {count}")
# Render one page of offers with GPU specs, price, and reliability
def _print_offer_page(self, offers: list[dict], *, page: int, page_size: int) -> None:
start = page * page_size
end = min(start + page_size, len(offers))
print(f"Offers {start + 1}{end} of {len(offers)}")
for idx in range(start, end):
o = offers[idx]
vram_gb = o.get("gpu_ram", 0) / 1024
duration = self._fmt_duration(float(o.get("duration") or 0))
print(
f" [{idx + 1}] {o.get('gpu_name')} "
f"gpus={o.get('num_gpus')} "
f"vram={vram_gb:.0f}GB "
f"dlperf={o.get('dlperf', 0):.1f} "
f"$/h={o.get('dph_total', 0):.3f} "
f"reliability={o.get('reliability2', 0):.3f} "
f"avail={duration} "
f"region={o.get('geolocation')} "
f"id={o.get('id')}"
)
# Paginated interactive prompt — user picks an offer or aborts
def _choose_offer(self, offers: list[dict]) -> dict:
page, page_size = 0, 10
last_page = max((len(offers) - 1) // page_size, 0)
while True:
self._print_offer_page(offers, page=page, page_size=page_size)
raw = input("Select offer [number / n / p / q]: ").strip().lower()
if raw == "n":
page = min(page + 1, last_page)
elif raw == "p":
page = max(page - 1, 0)
elif raw == "q":
raise OfferSelectionAborted
elif raw.isdigit():
idx = int(raw) - 1
if 0 <= idx < len(offers):
return offers[idx]
print(" Out of range.")
else:
print(" Invalid input.")
# ── template selection ────────────────────────────────────────────
# Decide which Vast template hash to use: CLI flag, interactive pick, or config default
def _resolve_template(self, opts: RunOptions) -> str | None:
if opts.template_hash:
return opts.template_hash
if opts.select_template:
try:
return self._choose_template()
except (TemplateSelectionAborted, KeyboardInterrupt):
print("Template selection aborted, using default image.")
return None
return self.config.instance.get("template_hash_id")
# Interactive prompt — user picks a Docker image template
def _choose_template(self) -> str | None:
print("Available templates:")
for idx, tpl in enumerate(self.TEMPLATE_CATALOG, 1):
print(f" [{idx}] {tpl['name']} ({tpl['image'] or 'config image'})")
while True:
raw = input("Select template [number / q]: ").strip().lower()
if raw == "q":
raise TemplateSelectionAborted
if raw.isdigit():
idx = int(raw) - 1
if 0 <= idx < len(self.TEMPLATE_CATALOG):
tpl = self.TEMPLATE_CATALOG[idx]
if tpl["hash_id"] == "none":
return None
if tpl["image"]:
self.config.instance["image"] = tpl["image"]
print(f" Selected: {tpl['name']}")
return tpl["hash_id"]
print(" Invalid input.")
# ── instance lifecycle ────────────────────────────────────────────
# Build a unique instance label from a prefix, stem, and UTC timestamp
def _label_for(self, stem: str) -> str:
ts = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
return f"{self.config.instance['label_prefix']}-{stem}-{ts}"
# Assemble the JSON payload sent to Vast.ai when renting an instance
def _build_instance_payload(self, label: str, *, template_hash_id: str | None = None) -> dict[str, Any]:
inst = self.config.instance
payload: dict[str, Any] = {
"image": inst["image"],
"disk": inst["disk_gb"],
"runtype": "ssh_direc ssh_proxy",
"target_state": inst["target_state"],
"label": label,
}
if template_hash_id:
payload["template_hash_id"] = template_hash_id
return payload
# Poll the API until the instance is running and has SSH credentials
def _wait_for_instance(self, instance_id: int) -> VastInstance:
deadline = time.time() + self.config.remote["ssh_timeout_seconds"]
poll = self.config.remote["poll_interval_seconds"]
while time.time() < deadline:
instance = self.api.show_instance(instance_id)
if instance.actual_status == "running" and instance.ssh_host and instance.ssh_port:
return instance
print(f" instance status: {instance.actual_status or 'pending'}...")
time.sleep(poll)
raise VastApiError(f"Timed out waiting for instance {instance_id} to be SSH-ready.")
# Build an SshTarget from the running instance's connection details
def _build_target(self, instance: VastInstance) -> SshTarget:
return SshTarget(
user=self.config.remote["ssh_user"],
host=instance.ssh_host or instance.public_ipaddr or "",
port=int(instance.ssh_port or 22),
private_key=self.private_key,
)
# ── manifest ──────────────────────────────────────────────────────
# Persist the run manifest as a timestamped JSON file under the output directory
def _write_manifest(self, manifest: RunManifest, local_output_root: Path) -> Path:
d = local_output_root / "pipeline"
d.mkdir(parents=True, exist_ok=True)
ts = manifest.created_at.replace(":", "").replace("-", "")
path = d / f"{ts}.json"
with open(path, "w", encoding="utf-8") as fh:
json.dump(asdict(manifest), fh, indent=2)
return path
# ── remote workspace setup ────────────────────────────────────────
# Upload project files, bootstrap the venv, and ensure training data is present
def _prepare_remote_workspace(
self, target: SshTarget, *, download_data: bool, send_cropped: bool,
first_config: Path | None = None,
) -> None:
ws = self.config.remote["workspace_dir"]
poll = self.config.remote["poll_interval_seconds"]
print(" Syncing project files...")
ssh(target, shell_join(["mkdir", "-p", ws]))
rsync_project(self.project_root, target, ws, self.exclude_file)
print(" Bootstrapping remote environment...")
ssh(target, shell_join(["bash", "-lc", f"cd {ws} && bash pipeline/scripts/bootstrap_env.sh"]))
if send_cropped and first_config is not None:
# Determine which cropped subdirectory to send based on the first config
first_mod, _ = self._module_for_config(first_config)
crop_subdir = "generator" if first_mod.startswith("generator/") else "classifier"
local_zip = self.project_root / f"cropped_{crop_subdir}.zip"
if not local_zip.exists():
print(f" Warning: --send-cropped set but cropped_{crop_subdir}.zip not found.")
print(f" Create it with: zip -r cropped_{crop_subdir}.zip cropped/{crop_subdir}/")
else:
size_mb = local_zip.stat().st_size / (1024 * 1024)
print(f" Sending cropped_{crop_subdir}.zip ({size_mb:.0f} MB)...")
rsync_file(local_zip, target, ws)
print(" Unzipping on remote...")
ssh(target, shell_join(["bash", "-lc", f"cd {ws} && unzip -q -o cropped_{crop_subdir}.zip"]))
print(f" Pre-cropped {crop_subdir} images ready.")
data_ready = self._remote_data_exists(target, f"{ws}/data")
if download_data or not data_ready:
# Download the dataset in a detached process so the SSH timeout doesn't kill it
print(" Downloading dataset from HuggingFace...")
log = f"{ws}/.pipeline_logs/fetch_ds.log"
ssh(target, f"mkdir -p {ws}/.pipeline_logs")
pid = run_detached(
target,
f"cd {ws} && source .venv/bin/activate && python3 classifier/tools/fetch_ds.py",
log,
)
reconnects = 0
# Poll until the fetch process exits, with reconnect logic for SSH drops
while True:
try:
if not is_process_alive(target, pid):
break
reconnects = 0
except RemoteCommandError as exc:
if "Command failed (255)" not in str(exc) or reconnects >= 10:
raise
reconnects += 1
time.sleep(min(5 * reconnects, 30))
continue
time.sleep(poll)
print(" Dataset ready.")
# ── config routing ────────────────────────────────────────────────
# Determine which training script and output dir to use based on config location
def _module_for_config(self, config_path: Path) -> tuple[str, str]:
try:
rel = config_path.resolve().relative_to(self.project_root)
except ValueError:
rel = config_path
if rel.parts[0] == "generator":
return "generator/run.py", "generator/outputs"
return "classifier/run.py", "classifier/outputs"
# Merge two dicts recursively (override wins on leaf keys)
@staticmethod
def _deep_merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = EphemeralVastRunner._deep_merge_dicts(result[key], value)
else:
result[key] = value
return result
# Recursively resolve "extends" references in JSON configs with cycle detection
def _load_config_with_extends(self, config_path: Path, seen: set[Path] | None = None) -> dict[str, Any]:
if seen is None:
seen = set()
resolved = config_path.resolve()
if resolved in seen:
raise ValueError(f"Circular config inheritance detected at: {config_path}")
seen.add(resolved)
with open(config_path, encoding="utf-8") as fh:
cfg = json.load(fh)
base_ref = cfg.pop("extends", None)
if not base_ref:
seen.remove(resolved)
return cfg
base_path = (config_path.parent / base_ref).resolve()
base_cfg = self._load_config_with_extends(base_path, seen=seen)
seen.remove(resolved)
return self._deep_merge_dicts(base_cfg, cfg)
# Return a stable signature used to detect duplicate training configs
def _normalized_config_signature(self, config_path: Path) -> str:
run_script, _ = self._module_for_config(config_path)
# Generator configs do not currently use shared/extends inheritance.
if run_script.startswith("generator/"):
with open(config_path, encoding="utf-8") as fh:
cfg = json.load(fh)
else:
cfg = self._load_config_with_extends(config_path)
shared_path = config_path.parent.parent / "shared.json"
if shared_path.exists():
with open(shared_path, encoding="utf-8") as fh:
shared_cfg = json.load(fh)
cfg = self._deep_merge_dicts(shared_cfg, cfg)
# run_name should not influence whether two configs are equivalent to train.
cfg.pop("run_name", None)
return json.dumps(cfg, sort_keys=True, separators=(",", ":"))
# Detect configs that are pure extends (only run_name + extends, no new training settings)
# These are pointers to an already-trained experiment and should be skipped
def _is_pure_extend(self, config_path: Path) -> str | None:
with open(config_path, encoding="utf-8") as fh:
raw = json.load(fh)
if "extends" not in raw:
return None
non_meta = {k for k in raw if k not in ("run_name", "extends")}
if non_meta:
return None
base_path = (config_path.parent / raw["extends"]).resolve()
return str(base_path.relative_to(self.project_root))
def _dedupe_training_configs(self, config_paths: list[Path]) -> list[Path]:
seen: dict[str, Path] = {}
deduped: list[Path] = []
for cp in config_paths:
pure_extends = self._is_pure_extend(cp)
if pure_extends is not None:
print(f"Skipping {cp.name} (pure extend of {pure_extends})")
continue
sig = self._normalized_config_signature(cp)
if sig in seen:
first = seen[sig]
print(f"Skipping duplicate config {cp.name} (same training settings as {first.name})")
continue
seen[sig] = cp
deduped.append(cp)
return deduped
# ── remote directory checks ───────────────────────────────────────
# Check whether a directory exists on the remote host
def _remote_dir_exists(self, target: SshTarget, path: str) -> bool:
out = ssh_output(target, f"if [ -d {path} ]; then echo yes; else echo no; fi").strip()
return out == "yes"
# Check that the dataset is populated (not just an empty directory)
def _remote_data_exists(self, target: SshTarget, data_dir: str) -> bool:
out = ssh_output(target, f"if [ -d {data_dir}/wiki ]; then echo yes; else echo no; fi").strip()
return out == "yes"
# ── training ─────────────────────────────────────────────────────
# Launch the training script inside nohup, stream logs, and handle reconnects
def _run_training(self, target: SshTarget, config_path: Path, opts: RunOptions) -> None:
ws = self.config.remote["workspace_dir"]
config_rel = config_path.resolve().relative_to(self.project_root)
run_script, output_root = self._module_for_config(config_path)
is_generator = run_script.startswith("generator/")
run_cmd = shell_join(["python3", "-u", run_script, str(config_rel), "--output-root", output_root])
if opts.use_gpu:
run_cmd += " --use-gpu"
log_dir = f"{ws}/.pipeline_logs"
log_path = f"{log_dir}/{config_path.stem}.log"
exit_path = f"{log_dir}/{config_path.stem}.exit"
ssh(target, f"mkdir -p {log_dir}")
# Run training under nohup; write the exit code to a marker file when done
pid = run_detached(
target,
f"cd {ws} && source .venv/bin/activate && {run_cmd}; echo $? > {exit_path}",
log_path,
)
print(f" Training started (PID {pid})")
poll = self.config.remote["poll_interval_seconds"]
max_reconnects = 10
reconnects = 0
next_sync = _GENERATOR_SYNC_INTERVAL if is_generator else None
local_output_root = self.project_root / output_root
stop_streaming = threading.Event()
# Background thread: tail the remote log and print tqdm-aware output
def _stream_worker() -> None:
tail_proc = None
last_was_progress = False
try:
time.sleep(1) # wait for the log file to be created
tail_proc = tail_log(target, log_path)
for raw in tail_proc.stdout:
if stop_streaming.is_set():
break
# tqdm uses \r for in-place updates; take only the last segment
line = raw.rstrip("\r\n").rsplit("\r", 1)[-1].rstrip()
if not line:
continue
is_progress = "it/s]" in line or "%|" in line
if is_progress:
print(f"\r {line}", end="", flush=True)
last_was_progress = True
else:
if last_was_progress:
print(flush=True)
last_was_progress = False
print(f" {line}", flush=True)
except Exception:
pass
finally:
if last_was_progress:
print(flush=True)
if tail_proc is not None:
tail_proc.terminate()
tail_proc.wait(timeout=5)
stream_thread = threading.Thread(target=_stream_worker, daemon=True)
stream_thread.start()
try:
# Main polling loop: check if the process is alive, periodically sync generator outputs
while True:
try:
alive = is_process_alive(target, pid)
except RemoteCommandError as exc:
if "Command failed (255)" not in str(exc) or reconnects >= max_reconnects:
raise
reconnects += 1
print(f"\n SSH dropped (reconnect {reconnects}/{max_reconnects}), "
"training continues on remote...")
time.sleep(min(5 * reconnects, 30))
continue
if is_generator and next_sync is not None:
epoch = self._latest_epoch_in_log(target, log_path)
if epoch >= next_sync:
print(f" Syncing generator outputs at epoch {next_sync}...")
self._fetch_outputs(target, output_root, local_output_root)
next_sync += _GENERATOR_SYNC_INTERVAL
if not alive:
break
reconnects = 0
time.sleep(poll)
finally:
stop_streaming.set()
stream_thread.join(timeout=10)
# Read the exit code marker that the nohup wrapper wrote
raw_exit = read_remote_file(target, exit_path)
exit_code = int(raw_exit) if raw_exit is not None else -1
if exit_code != 0:
if self._remote_dir_exists(target, f"{ws}/{output_root}"):
self._fetch_outputs(target, output_root, local_output_root)
log_tail = read_remote_file(target, log_path)
snippet = "\n".join((log_tail or "").splitlines()[-30:])
raise RemoteCommandError(f"Training failed (exit {exit_code}).\n{snippet}")
# Parse the latest epoch number from the training log (used for generator sync scheduling)
@staticmethod
def _latest_epoch_in_log(target: SshTarget, log_path: str) -> int:
try:
out = ssh_output(target, f"tail -n 200 {log_path} 2>/dev/null || true")
epochs = [int(m) for m in re.findall(r"\[(\d+)/\d+\]", out)]
return max(epochs, default=0)
except RemoteCommandError:
return 0
# Download remote outputs with up to 3 retry attempts (exponential back-off)
def _fetch_outputs(self, target: SshTarget, remote_output_root: str, local_output_root: Path) -> None:
remote_path = f"{self.config.remote['workspace_dir']}/{remote_output_root}"
delay = 5
for attempt in range(1, 4):
try:
fetch_directory(target, remote_path, local_output_root)
return
except RemoteCommandError as exc:
if attempt < 3:
print(f" Download attempt {attempt} failed: {exc}")
print(f" Retrying in {delay}s...")
time.sleep(delay)
delay *= 2
else:
raise
# ── public commands ───────────────────────────────────────────────
# Full pipeline: validate configs -> find offer -> rent instance -> train -> fetch -> destroy
def run(self, config_paths: list[Path], opts: RunOptions) -> None:
resolved = []
for cp in config_paths:
cp = (self.project_root / cp).resolve() if not cp.is_absolute() else cp
if not cp.exists():
raise FileNotFoundError(f"Config not found: {cp}")
resolved.append(cp)
resolved = self._dedupe_training_configs(resolved)
# Abort early if all configs were duplicates
if not resolved:
raise ValueError("No unique configs to run after deduplication.")
n = len(resolved)
_, first_output_root = self._module_for_config(resolved[0])
local_output_root = self.project_root / first_output_root
self.api.ensure_ssh_key(self.public_key)
offers = self._search_offers(opts.sort_mode, opts.region, opts.price_cap)
if opts.list_regions:
self._print_regions(offers)
return
try:
offer = self._choose_offer(offers) if opts.select_offer else offers[0]
except (OfferSelectionAborted, KeyboardInterrupt):
print("Aborted.")
return
if opts.dry_run:
print(json.dumps(offer if opts.select_offer else offers[:10], indent=2))
return
label_stem = resolved[0].stem if n == 1 else resolved[0].parent.name
template_hash_id = self._resolve_template(opts)
offer_id = int(offer["id"])
cost_per_hour = float(offer.get("dph_total", 0))
manifest = RunManifest(
created_at=datetime.now(UTC).isoformat(),
config_paths=[str(cp.relative_to(self.project_root)) for cp in resolved],
instance_id=None,
offer_id=offer_id,
ssh_host=None,
ssh_port=None,
status="creating",
remote_workspace=self.config.remote["workspace_dir"],
)
manifest_path = self._write_manifest(manifest, local_output_root)
instance_id: int | None = None
should_destroy = True
start_time = time.time()
try:
# Build the instance payload and rent the GPU
payload = self._build_instance_payload(self._label_for(label_stem),
template_hash_id=template_hash_id)
print(f"Creating instance from offer {offer_id} (${cost_per_hour:.3f}/h)...")
instance_id = self.api.create_instance(offer_id, payload)
self.api.attach_ssh_key(instance_id, self.public_key)
# Record the instance ID in the manifest immediately for debugging
manifest.instance_id = instance_id
# Wait for the instance to be SSH-ready, then update the manifest
print(f"Waiting for instance {instance_id}...")
instance = self._wait_for_instance(instance_id)
manifest.ssh_host = instance.ssh_host
manifest.ssh_port = instance.ssh_port
manifest.status = instance.actual_status
self._write_manifest(manifest, local_output_root)
print(f"Instance ready: {instance.ssh_host}:{instance.ssh_port} "
f"({instance.gpu_name}, ${instance.dph_total}/h)")
target = self._build_target(instance)
# Wait for SSH to become available before running any remote commands
print("Waiting for SSH...")
wait_for_ssh(
target,
timeout_seconds=self.config.remote["ssh_timeout_seconds"],
poll_interval_seconds=self.config.remote["poll_interval_seconds"],
)
self._prepare_remote_workspace(
target, download_data=opts.download_data, send_cropped=opts.send_cropped,
first_config=resolved[0],
)
# Run each training config sequentially on the same instance
for i, config_path in enumerate(resolved, 1):
print(f"\n[{i}/{n}] Training: {config_path.name}")
self._run_training(target, config_path, opts)
print(f"[{i}/{n}] Fetching outputs...")
_, output_root = self._module_for_config(config_path)
self._fetch_outputs(target, output_root, self.project_root / output_root)
elapsed = time.time() - start_time
manifest.status = "completed"
self._write_manifest(manifest, local_output_root)
print(f"\nAll {n} run(s) completed in {self._fmt_duration(elapsed)} "
f"(~{self._fmt_cost(elapsed, cost_per_hour)}). Manifest: {manifest_path}")
except KeyboardInterrupt:
elapsed = time.time() - start_time
manifest.status = "cancelled"
self._write_manifest(manifest, local_output_root)
print(f"\nCancelled after {self._fmt_duration(elapsed)} "
f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
except Exception:
elapsed = time.time() - start_time
manifest.status = "failed"
self._write_manifest(manifest, local_output_root)
print(f" Failed after {self._fmt_duration(elapsed)} "
f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
if opts.keep_on_failure or self.config.keep_on_failure:
should_destroy = False
raise
finally:
if instance_id is not None and should_destroy:
elapsed = time.time() - start_time
print(f"Destroying instance {instance_id} "
f"(total: {self._fmt_duration(elapsed)}, "
f"~{self._fmt_cost(elapsed, cost_per_hour)})")
self.api.destroy_instance(instance_id)
# CLI subcommand: search and display available GPU offers without renting
def offers(
self,
*,
sort_mode: str | None,
region: str | None,
price_cap: float | None,
select_offer: bool,
list_regions: bool,
limit_output: int,
) -> None:
offers = self._search_offers(sort_mode, region, price_cap)
if list_regions:
self._print_regions(offers)
return
if select_offer:
try:
offer = self._choose_offer(offers)
except (OfferSelectionAborted, KeyboardInterrupt):
print("Aborted.")
return
print(json.dumps(offer, indent=2))
else:
print(json.dumps(offers[:limit_output], indent=2))
# CLI subcommand: rent an instance and print its connection details (no training)
def up(self, *, label: str | None, opts: RunOptions | None = None) -> None:
if opts is None:
opts = RunOptions()
self.api.ensure_ssh_key(self.public_key)
template_hash_id = self._resolve_template(opts)
offer = self._search_offers()[0]
offer_id = int(offer["id"])
payload = self._build_instance_payload(
label or self._label_for("manual"), template_hash_id=template_hash_id
)
instance_id = self.api.create_instance(offer_id, payload)
self.api.attach_ssh_key(instance_id, self.public_key)
instance = self._wait_for_instance(instance_id)
print(json.dumps({
"offer_id": offer_id,
"instance_id": instance_id,
"ssh_host": instance.ssh_host,
"ssh_port": instance.ssh_port,
"gpu_name": instance.gpu_name,
"dph_total": instance.dph_total,
}, indent=2))
# CLI subcommand: print the raw instance status JSON
def status(self, instance_id: int) -> None:
print(json.dumps(self.api.show_instance(instance_id).raw, indent=2))
# CLI subcommand: destroy a running instance by ID
def down(self, instance_id: int) -> None:
self.api.destroy_instance(instance_id)
print(f"Destroyed instance {instance_id}")
+180
View File
@@ -0,0 +1,180 @@
import shlex
import subprocess
import time
from dataclasses import dataclass
from pathlib import Path
# Error raised when a remote SSH or rsync command fails
class RemoteCommandError(RuntimeError):
pass
# Connection details for an SSH target (Vast instance)
@dataclass(slots=True)
class SshTarget:
user: str
host: str
port: int
private_key: Path
# user@host string used by SSH and rsync
@property
def destination(self) -> str:
return f"{self.user}@{self.host}"
# Base SSH flags: identity, port, keep-alive, and auto-accept host keys
def ssh_base_command(self) -> list[str]:
return [
"ssh",
"-i", str(self.private_key),
"-p", str(self.port),
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=accept-new",
"-o", "ServerAliveInterval=30",
"-o", "ServerAliveCountMax=10",
"-o", "TCPKeepAlive=yes",
self.destination,
]
# SSH flags formatted as a string for rsync's -e option
def rsync_ssh_opts(self) -> str:
return " ".join(self.ssh_base_command()[:-1])
# ── low-level subprocess helpers ─────────────────────────────────────
# Run a local subprocess, raising RemoteCommandError on non-zero exit
def _run(command: list[str]) -> None:
result = subprocess.run(command, check=False, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
snippet = result.stderr.strip()[-500:] if result.stderr else ""
msg = f"Command failed ({result.returncode}): {' '.join(command)}"
raise RemoteCommandError(f"{msg}\n{snippet}" if snippet else msg)
# ── SSH ───────────────────────────────────────────────────────────────
# Repeatedly attempt an SSH connection until the target accepts it or the deadline passes
def wait_for_ssh(target: SshTarget, *, timeout_seconds: int, poll_interval_seconds: int) -> None:
deadline = time.time() + timeout_seconds
while time.time() < deadline:
result = subprocess.run(
target.ssh_base_command() + ["true"],
check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
)
if result.returncode == 0:
return
remaining = int(deadline - time.time())
print(f" SSH not ready yet, retrying... ({remaining}s remaining)")
time.sleep(poll_interval_seconds)
raise RemoteCommandError(
f"Timed out waiting for SSH on {target.host}:{target.port} after {timeout_seconds}s."
)
# Execute a single command on the remote host via SSH
def ssh(target: SshTarget, command: str) -> None:
_run(target.ssh_base_command() + [command])
# Execute a remote command and return its stdout
def ssh_output(target: SshTarget, command: str) -> str:
result = subprocess.run(
target.ssh_base_command() + [command],
check=False, capture_output=True, text=True,
)
if result.returncode != 0:
raise RemoteCommandError(
f"Command failed ({result.returncode}): {command}\n{result.stderr.strip()}"
)
return result.stdout
# ── detached process management ───────────────────────────────────────
# Launch a remote command under nohup so it survives SSH drops; return PID
def run_detached(target: SshTarget, command: str, log_path: str) -> int:
inner = f"nohup bash -c {shlex.quote(command)} > {shlex.quote(log_path)} 2>&1 & echo $!"
output = ssh_output(target, inner)
return int(output.strip().splitlines()[-1])
# Return True if the remote process is running
# Raise RemoteCommandError on SSH failure (exit 255) so callers can apply
# reconnect logic and only return False when kill -0 confirms the PID is gone
def is_process_alive(target: SshTarget, pid: int) -> bool:
try:
ssh(target, f"kill -0 {pid}")
return True
except RemoteCommandError as exc:
if "Command failed (255)" in str(exc):
raise # SSH failure, let caller handle
return False # process not found
# Read a remote file's contents; returns None if the file does not exist
def read_remote_file(target: SshTarget, path: str) -> str | None:
return ssh_output(target, f"cat {shlex.quote(path)} 2>/dev/null || true").strip() or None
# Open a persistent tail -f on a remote file for real-time streaming
def tail_log(target: SshTarget, log_path: str) -> subprocess.Popen:
cmd = target.ssh_base_command() + [f"tail -f -n +1 {shlex.quote(log_path)}"]
return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True)
# ── rsync ─────────────────────────────────────────────────────────────
# Upload the entire project directory, respecting the exclude list
def rsync_project(local_root: Path, target: SshTarget, remote_root: str, exclude_file: Path) -> None:
_run([
"rsync", "-az", "--info=progress2", "--delete",
"--exclude-from", str(exclude_file),
"-e", target.rsync_ssh_opts(),
f"{local_root}/",
f"{target.destination}:{remote_root}/",
])
# Upload a single file to a remote directory
def rsync_file(local_file: Path, target: SshTarget, remote_dir: str) -> None:
_run([
"rsync", "-az", "--partial", "--info=progress2",
"-e", target.rsync_ssh_opts(),
str(local_file),
f"{target.destination}:{remote_dir}/",
])
# Download a remote directory, verify file count, and resume partial transfers
def fetch_directory(target: SshTarget, remote_dir: str, local_dir: Path) -> None:
local_dir.mkdir(parents=True, exist_ok=True)
# Count remote files first so we can verify completeness after transfer
remote_count = int(ssh_output(
target, f"find {remote_dir.rstrip('/')} -type f 2>/dev/null | wc -l"
).strip())
_run([
"rsync", "-az", "--partial", "--append-verify", "--info=progress2",
"-e", target.rsync_ssh_opts(),
f"{target.destination}:{remote_dir.rstrip('/')}/",
f"{local_dir}/",
])
local_count = sum(1 for p in local_dir.rglob("*") if p.is_file())
# Raise if fewer files arrived than expected (corrupted / interrupted transfer)
if local_count < remote_count:
raise RemoteCommandError(
f"Download incomplete: got {local_count} files, expected {remote_count}. "
f"Remote: {remote_dir} Local: {local_dir}"
)
print(f" Verified: {local_count} files downloaded successfully")
# ── misc ──────────────────────────────────────────────────────────────
# Quote and join shell arguments for safe remote execution
def shell_join(parts: list[str]) -> str:
return " ".join(shlex.quote(p) for p in parts)
+15
View File
@@ -0,0 +1,15 @@
.claude/
.git/
.venv/
__pycache__/
.ipynb_checkpoints/
/data/
/cropped/
/cropped_classifier.zip
/cropped_generator.zip
/classifier/outputs/
/generator/outputs/
+48
View File
@@ -0,0 +1,48 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$ROOT_DIR"
PYTHON_BIN="${PYTHON_BIN:-python3}"
USE_SYSTEM_SITE_PACKAGES="${USE_SYSTEM_SITE_PACKAGES:-1}"
SKIP_TORCH_INSTALL="${SKIP_TORCH_INSTALL:-1}"
VENV_ARGS=()
if [ "$USE_SYSTEM_SITE_PACKAGES" = "1" ]; then
VENV_ARGS+=(--system-site-packages)
fi
"$PYTHON_BIN" -m venv "${VENV_ARGS[@]}" .venv
source .venv/bin/activate
python -m pip install --upgrade pip setuptools wheel
# Capture system torch versions before venv installs override them
TORCH_VERSION="$($PYTHON_BIN -c 'import torch; print(torch.__version__)' 2>/dev/null || true)"
TV_VERSION="$($PYTHON_BIN -c 'import torchvision; print(torchvision.__version__)' 2>/dev/null || true)"
if [ "$SKIP_TORCH_INSTALL" = "1" ]; then
filtered_requirements="$(mktemp)"
grep -Ev '^(torch|torchvision)([<>=!~].*)?$' requirements.txt > "$filtered_requirements"
# Pin system torch/torchvision so transitive deps (e.g. facenet-pytorch) can't downgrade them
constraints_file="$(mktemp)"
[ -n "$TORCH_VERSION" ] && echo "torch==$TORCH_VERSION" >> "$constraints_file"
[ -n "$TV_VERSION" ] && echo "torchvision==$TV_VERSION" >> "$constraints_file"
python -m pip install -c "$constraints_file" -r "$filtered_requirements"
rm -f "$filtered_requirements" "$constraints_file"
else
python -m pip install -r requirements.txt
fi
mkdir -p classifier/outputs/logs classifier/outputs/models classifier/outputs/analysis classifier/outputs/figures classifier/outputs/pipeline
python - <<'PY'
try:
import torch
print(f"torch={torch.__version__} cuda_available={torch.cuda.is_available()}")
except Exception as exc:
print(f"torch check failed: {exc}")
PY
+131
View File
@@ -0,0 +1,131 @@
import json
from dataclasses import dataclass
from typing import Any
from urllib import error, request
# Generic error raised for any Vast.ai API failure
class VastApiError(RuntimeError):
pass
# Lightweight view of a Vast.ai instance with the fields the pipeline cares about
@dataclass(slots=True)
class VastInstance:
id: int
actual_status: str
ssh_host: str | None
ssh_port: int | None
public_ipaddr: str | None
gpu_name: str | None
dph_total: float | None
raw: dict[str, Any]
# Thin wrapper around the Vast.ai REST API
class VastApiClient:
def __init__(self, api_key: str, *, base_url: str = "https://console.vast.ai") -> None:
self.api_key = api_key
self.base_url = base_url.rstrip("/")
# Low-level request helper — sends JSON, returns parsed response body
def _request(self, method: str, path: str, payload: dict[str, Any] | None = None) -> Any:
url = f"{self.base_url}{path}"
data = None
headers = {"Authorization": f"Bearer {self.api_key}"}
if payload is not None:
headers["Content-Type"] = "application/json"
data = json.dumps(payload).encode("utf-8")
req = request.Request(url, method=method, data=data, headers=headers)
try:
with request.urlopen(req, timeout=60) as response:
body = response.read()
# Vast.ai returns varied error formats; surface whatever body we get
except error.HTTPError as exc:
details = exc.read().decode("utf-8", errors="replace")
raise VastApiError(f"{method} {path} failed with {exc.code}: {details}") from exc
except error.URLError as exc:
raise VastApiError(f"{method} {path} failed: {exc.reason}") from exc
if not body:
return None
return json.loads(body)
# Fetch the currently authenticated user's profile
def show_user(self) -> dict[str, Any]:
return self._request("GET", "/api/v0/users/current/")
# ── SSH keys ────────────────────────────────────────────────────────
# List registered SSH keys; handles inconsistent response shapes from the API
def show_ssh_keys(self) -> list[dict[str, Any]]:
response = self._request("GET", "/api/v0/ssh/")
if isinstance(response, list):
return response
if isinstance(response, dict):
for key in ("keys", "ssh_keys"):
value = response.get(key)
if isinstance(value, list):
return value
raise VastApiError(f"Unexpected SSH key response: {response}")
# Register the public key if it isn't already present
def ensure_ssh_key(self, public_key: str) -> None:
existing_keys = self.show_ssh_keys()
if any(
(
item.get("key")
or item.get("public_key")
or item.get("ssh_key")
or ""
).strip() == public_key
for item in existing_keys
):
return
self._request("POST", "/api/v0/ssh/", {"ssh_key": public_key})
# Authorise an SSH key for a running instance
def attach_ssh_key(self, instance_id: int, public_key: str) -> None:
self._request("POST", f"/api/v0/instances/{instance_id}/ssh/", {"ssh_key": public_key})
# ── Offers ─────────────────────────────────────────────────────────
# Search available GPU offers matching a query filter
def search_offers(self, query: dict[str, Any]) -> list[dict[str, Any]]:
response = self._request("POST", "/api/v0/bundles/", query)
offers = response.get("offers", [])
if isinstance(offers, dict):
return [offers]
return offers
# ── Instances ──────────────────────────────────────────────────────
# Rent an offer, returning the new contract (instance) ID
def create_instance(self, offer_id: int, payload: dict[str, Any]) -> int:
response = self._request("PUT", f"/api/v0/asks/{offer_id}/", payload)
if not response or not response.get("success"):
raise VastApiError(f"Instance creation failed for offer {offer_id}: {response}")
return int(response["new_contract"])
# Fetch current status and connection details for an instance
def show_instance(self, instance_id: int) -> VastInstance:
response = self._request("GET", f"/api/v0/instances/{instance_id}/")
raw = response.get("instances")
if not raw:
raise VastApiError(f"No instance details found for {instance_id}: {response}")
return VastInstance(
id=int(raw["id"]),
actual_status=raw.get("actual_status", ""),
ssh_host=raw.get("ssh_host"),
ssh_port=raw.get("ssh_port"),
public_ipaddr=raw.get("public_ipaddr"),
gpu_name=raw.get("gpu_name"),
dph_total=raw.get("dph_total"),
raw=raw,
)
# Permanently destroy an instance (releases the GPU and billing)
def destroy_instance(self, instance_id: int) -> None:
self._request("DELETE", f"/api/v0/instances/{instance_id}/")