Clean state
@@ -0,0 +1,164 @@

# Pipeline

Orchestrates ephemeral Vast.ai GPU instances: searches for an offer, creates the instance, syncs the project, trains, downloads `outputs/`, and destroys the instance automatically. Generator runs also rsync `generator/outputs/` every 50 epochs while training is still running.

## One-time setup

Create `pipeline/.env`:

```dotenv
VAST_API_KEY=your-vast-api-key
# Optional; this is the default:
VAST_SSH_PRIVATE_KEY=/home/you/.ssh/id_ed25519
```

The matching `.pub` file must exist alongside the private key. The pipeline registers it with Vast.ai automatically if it isn't there yet.
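To check this precondition up front, the same lookup the pipeline performs can be mirrored in a few lines. A minimal sketch — `check_keypair` is an illustrative name, and the path handling mirrors `resolve_ssh_private_key`/`read_public_key` in `pipeline/config.py`:

```python
from pathlib import Path

def check_keypair(private_key: str = "~/.ssh/id_ed25519") -> str:
    """Return the public key text, failing fast if either half of the pair is missing."""
    priv = Path(private_key).expanduser()
    pub = Path(f"{priv}.pub")  # the .pub must sit next to the private key
    if not priv.exists():
        raise RuntimeError(f"SSH private key not found at {priv}")
    if not pub.exists():
        raise RuntimeError(f"SSH public key not found at {pub}")
    return pub.read_text(encoding="utf-8").strip()
```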

## Commands

### `run` — train on a remote GPU and fetch results

```
python -m pipeline run <config...> [options]
```

Accepts one or more config paths, or a single directory (all `*.json` inside, sorted). Duplicate configs (identical training settings after resolving `extends` and `shared.json`) are skipped automatically.
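The dedup step can be pictured as keeping the first config per distinct set of resolved settings. A hedged sketch — it assumes each config has already been resolved to a plain dict (the actual resolution of `extends` and `shared.json` lives in the training code):

```python
import json

def dedup_configs(resolved: list[dict]) -> list[dict]:
    """Keep the first config for each distinct set of training settings."""
    seen: set[str] = set()
    unique: list[dict] = []
    for cfg in resolved:
        key = json.dumps(cfg, sort_keys=True)  # canonical form, ignores key order
        if key not in seen:
            seen.add(key)
            unique.append(cfg)
    return unique
```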

| Flag | Default | Description |
|------|---------|-------------|
| `configs` | *(required)* | One or more config paths, or a directory of JSON configs |
| `--download-data` | off | Download the DFF dataset via HuggingFace on the remote before training |
| `--send-cropped` | off | Rsync local `cropped/{classifier,generator}/` to remote (picks subdirectory based on config) |
| `--select-offer` | off | Interactively browse and pick the GPU offer |
| `--sort` | config | Ranking mode: `price`, `performance`, or `dlp_per_dollar` |
| `--region TEXT` | any | Filter by region, e.g. `europe`, `Portugal`, `US` |
| `--price FLOAT` | config | Max hourly price cap in USD |
| `--dry-run` | off | Print matching offers without creating an instance |
| `--keep-on-failure` | off | Do not destroy the instance if training fails |
| `--no-gpu` | off | Disable GPU training on remote (use CPU instead) |
| `--select-template` | off | Interactively choose a Vast.ai Docker template |
| `--template HASH` | config | Use a specific template hash ID |
| `--pipeline-config PATH` | none | JSON file that overrides `pipeline/defaults/vast.json` |

**Examples:**

```bash
# Cheapest available RTX 3090 in Europe, download data on remote
python -m pipeline run configs/resnet18.json --region europe --download-data

# Browse offers interactively, sort by price
python -m pipeline run configs/resnet18.json --select-offer --sort price

# Run all configs in a directory sequentially on one instance
python -m pipeline run configs/phase2/ --region europe

# See what offers would be selected without spending money
python -m pipeline run configs/resnet18.json --dry-run --region europe

# Keep the instance alive if something goes wrong (for debugging)
python -m pipeline run configs/resnet18.json --keep-on-failure

# Cap price at $0.12/h
python -m pipeline run configs/resnet18.json --price 0.12
```

### `offers` — inspect available GPU offers

```
python -m pipeline offers [options]
```

| Flag | Default | Description |
|------|---------|-------------|
| `--sort` | config | Ranking mode: `price`, `performance`, or `dlp_per_dollar` |
| `--region TEXT` | any | Region filter |
| `--price FLOAT` | config | Max hourly price cap |
| `--select-offer` | off | Interactive offer picker (prints the selected offer as JSON) |
| `--list-regions` | off | Print a count of available offers per region and exit |
| `--limit-output INT` | 10 | How many offers to print |
| `--pipeline-config PATH` | none | Pipeline config override |

**Examples:**

```bash
# See the 20 best-value offers under $0.15/h in Europe
python -m pipeline offers --region europe --price 0.15 --limit-output 20

# List which regions have matching GPUs
python -m pipeline offers --list-regions

# Interactive picker — useful before committing to a run
python -m pipeline offers --select-offer --sort price
```

### `up` — create an instance without training

Spins up an instance and prints SSH connection details. Useful for manual experiments or debugging.

```
python -m pipeline up [options]
```

| Flag | Default | Description |
|------|---------|-------------|
| `--label TEXT` | auto | Optional label for the instance |
| `--select-template` | off | Interactively choose a Vast.ai Docker template |
| `--template HASH` | config | Use a specific template hash ID |
| `--pipeline-config PATH` | none | Pipeline config override |

```bash
python -m pipeline up
python -m pipeline up --label my-debug-session
```

### `status` — show instance details

```
python -m pipeline status <instance_id> [--pipeline-config PATH]
```

### `down` — destroy an instance

```
python -m pipeline down <instance_id> [--pipeline-config PATH]
```

## Pipeline config overrides

Pass `--pipeline-config my_overrides.json` to override any field from `pipeline/defaults/vast.json`. Only the fields you specify are changed; the rest keep their defaults (deep-merged). Useful for switching GPU types or raising the price cap for a single run without editing defaults.

**Example — allow RTX 4090, higher price cap:**

```json
{
  "search": {
    "gpu_names": ["RTX 4090"],
    "max_dph_total": 0.45
  }
}
```

**Key fields in `pipeline/defaults/vast.json`:**

| Section | Key | Default | Meaning |
|---------|-----|---------|---------|
| `search` | `gpu_names` | `["RTX 3090", "RTX 3090 Ti"]` | Accepted GPU models |
| `search` | `max_dph_total` | `0.40` | Max price per hour |
| `search` | `sort_mode` | `"dlp_per_dollar"` | Default ranking (`price`, `performance`, or `dlp_per_dollar`) |
| `search` | `min_reliability` | `0.98` | Minimum host reliability score |
| `instance` | `disk_gb` | `48` | Disk size provisioned on the instance |
| `instance` | `image` | `"vastai/pytorch:latest"` | Docker image |
| `remote` | `workspace_dir` | `"/workspace/DRL_PROJ"` | Remote working directory |
| `remote` | `ssh_timeout_seconds` | `900` | How long to wait for SSH to become available |

## Full workflow example

```bash
# 1. Check what's available and how much it costs
python -m pipeline offers --region europe --list-regions
python -m pipeline offers --region europe --sort price --limit-output 20

# 2. Run training (auto-selects best offer, downloads data if needed)
python -m pipeline run configs/resnet18.json --region europe --download-data

# 3. Results land in classifier/outputs/ automatically
```

@@ -0,0 +1,2 @@
"""Ephemeral Vast.ai training pipeline."""

@@ -0,0 +1,6 @@
from pipeline.cli import main


if __name__ == "__main__":
    raise SystemExit(main())

@@ -0,0 +1,130 @@
import argparse
from pathlib import Path

from pipeline.orchestrator import EphemeralVastRunner, RunOptions


# Accept one or more config files, or a single directory (all *.json inside, sorted)
def _resolve_configs(raw: list[str]) -> list[Path]:
    if len(raw) == 1 and Path(raw[0]).is_dir():
        configs = sorted(Path(raw[0]).glob("*.json"))
        if not configs:
            raise ValueError(f"No JSON configs found in directory: {raw[0]}")
        return configs
    return [Path(p) for p in raw]
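A quick check of the directory expansion (the helper is repeated here so the snippet runs standalone against a throwaway directory):

```python
import tempfile
from pathlib import Path

def resolve_configs(raw: list[str]) -> list[Path]:
    # Same logic as _resolve_configs above: a lone directory argument expands
    # to its *.json files in sorted order; anything else passes through.
    if len(raw) == 1 and Path(raw[0]).is_dir():
        configs = sorted(Path(raw[0]).glob("*.json"))
        if not configs:
            raise ValueError(f"No JSON configs found in directory: {raw[0]}")
        return configs
    return [Path(p) for p in raw]

d = tempfile.mkdtemp()
for name in ("b.json", "a.json", "notes.txt"):
    (Path(d) / name).touch()
print([p.name for p in resolve_configs([d])])  # ['a.json', 'b.json']
```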


# Build the argparse CLI with subcommands: offers, run, up, status, down
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Ephemeral Vast.ai training pipeline.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # Shared offer-related flags used by both "offers" and "run" subcommands
    def add_offer_options(p: argparse.ArgumentParser) -> None:
        p.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")
        p.add_argument("--sort", choices=["price", "performance", "dlp_per_dollar"], default=None, help="Override offer ranking mode.")
        p.add_argument("--region", default=None, help="Filter offers by region (e.g. 'europe', 'Portugal').")
        p.add_argument("--price", type=float, default=None, help="Override max hourly price cap in USD.")
        p.add_argument("--select-offer", action="store_true", help="Interactively choose an offer.")

    # Shared template-related flags used by "run" and "up" subcommands
    def add_template_options(p: argparse.ArgumentParser) -> None:
        p.add_argument("--select-template", action="store_true", help="Interactively choose a Vast.ai template.")
        p.add_argument("--template", default=None, help="Template hash ID to use for instance creation.")

    # ── offers ──────────────────────────────────────────────────────────

    offers_parser = subparsers.add_parser("offers", help="Inspect available Vast offers.")
    add_offer_options(offers_parser)
    offers_parser.add_argument("--list-regions", action="store_true", help="List matching regions and exit.")
    offers_parser.add_argument("--limit-output", type=int, default=10, help="How many offers to print (default: 10).")

    # ── run ─────────────────────────────────────────────────────────────

    run_parser = subparsers.add_parser("run", help="Create instance, train, fetch outputs, destroy it.")
    run_parser.add_argument("configs", nargs="+", help="One or more config paths, or a directory of JSON configs.")
    add_offer_options(run_parser)
    add_template_options(run_parser)
    run_parser.add_argument("--download-data", action="store_true", help="Force download dataset via HuggingFace on the remote before training (auto-downloads if missing).")
    run_parser.add_argument("--send-cropped", action="store_true", help="Rsync local cropped/ subdirectory to remote before training (sends only classifier or generator based on config).")
    run_parser.add_argument("--keep-on-failure", action="store_true", help="Keep instance on failure.")
    run_parser.add_argument("--dry-run", action="store_true", help="Search and print offers without creating instance.")
    run_parser.add_argument("--no-gpu", action="store_true", help="Disable GPU training on remote (use CPU instead).")

    # ── up ──────────────────────────────────────────────────────────────

    up_parser = subparsers.add_parser("up", help="Create instance and print SSH details.")
    up_parser.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")
    add_template_options(up_parser)
    up_parser.add_argument("--label", default=None, help="Optional label for the instance.")

    # ── status ──────────────────────────────────────────────────────────

    status_parser = subparsers.add_parser("status", help="Show instance details.")
    status_parser.add_argument("instance_id", type=int)
    status_parser.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")

    # ── down ────────────────────────────────────────────────────────────

    down_parser = subparsers.add_parser("down", help="Destroy an instance.")
    down_parser.add_argument("instance_id", type=int)
    down_parser.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")

    return parser


# Parse CLI args and dispatch to the appropriate runner method
def main(argv=None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)
    override = Path(args.pipeline_config) if getattr(args, "pipeline_config", None) else None
    runner = EphemeralVastRunner(override)

    # ── dispatch ────────────────────────────────────────────────────────

    if args.command == "offers":
        runner.offers(
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            select_offer=args.select_offer,
            list_regions=args.list_regions,
            limit_output=args.limit_output,
        )
        return 0

    if args.command == "run":
        opts = RunOptions(
            download_data=args.download_data,
            send_cropped=args.send_cropped,
            keep_on_failure=args.keep_on_failure,
            dry_run=args.dry_run,
            select_offer=args.select_offer,
            select_template=args.select_template,
            template_hash=args.template,
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            use_gpu=not args.no_gpu,  # Default to True, disable with --no-gpu
        )
        runner.run(_resolve_configs(args.configs), opts)
        return 0

    if args.command == "up":
        up_opts = RunOptions(
            select_template=getattr(args, "select_template", False),
            template_hash=getattr(args, "template", None),
        )
        runner.up(label=args.label, opts=up_opts)
        return 0

    if args.command == "status":
        runner.status(args.instance_id)
        return 0

    if args.command == "down":
        runner.down(args.instance_id)
        return 0

    parser.error(f"Unknown command: {args.command}")
    return 2

@@ -0,0 +1,182 @@
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


# Load environment variables from pipeline/.env if present
# This file is intentionally optional and gitignored
# Expected keys:
# VAST_API_KEY=<your-api-key>
# VAST_SSH_PRIVATE_KEY=/home/your-user/.ssh/id_ed25519
def load_dotenv(dotenv_path: Path) -> None:
    if not dotenv_path.exists():
        return
    for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        value = value.strip().strip("'").strip('"')
        os.environ.setdefault(key, value)
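Usage sketch for the loader above (re-implemented inline so it runs in isolation; note that variables already present in the environment win, because of `setdefault`):

```python
import os
import tempfile
from pathlib import Path

def load_dotenv(dotenv_path: Path) -> None:
    # Same parsing rules as above: skip blanks/comments, split on the first "=",
    # strip surrounding quotes, never override existing environment variables.
    if not dotenv_path.exists():
        return
    for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        os.environ.setdefault(key.strip(), value.strip().strip("'").strip('"'))

env = Path(tempfile.mkdtemp()) / ".env"
env.write_text('# demo file\nPIPELINE_DEMO_KEY="abc123"\n')
load_dotenv(env)
print(os.environ["PIPELINE_DEMO_KEY"])  # abc123
```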

# Default offer search settings
DEFAULT_SEARCH: dict[str, Any] = {
    "limit": 100,
    "order_by": "dlperf",
    "order_direction": "desc",
    "sort_mode": "performance",
    "offer_type": "ondemand",
    "verified_only": True,
    "rentable": True,
    "rented": False,
    "num_gpus": 1,
    "min_reliability": 0.98,
    "min_cuda_ram_mb": 0,
    "min_cpu_ram_mb": 0,
    "min_disk_space_gb": 0,
    "min_direct_port_count": 1,
    "max_dph_total": 0.20,
    "gpu_names": ["RTX 3090", "RTX 3090 Ti"],
    "countries_exclude": [],
}

# Default instance settings
DEFAULT_INSTANCE: dict[str, Any] = {
    "image": "vastai/pytorch:latest",
    "disk_gb": 48,
    "target_state": "running",
    "label_prefix": "drl-proj",
    "template_hash_id": None,
}

# Default remote SSH/workspace settings
DEFAULT_REMOTE: dict[str, Any] = {
    "ssh_user": "root",
    "workspace_dir": "/workspace/DRL_PROJ",
    "remote_data_dir": "/workspace/data/DeepFakeFace",
    "remote_output_root": "classifier/outputs",
    "ssh_timeout_seconds": 900,
    "poll_interval_seconds": 10,
}

# Default local transfer paths
DEFAULT_TRANSFER: dict[str, Any] = {
    "local_output_dir": "classifier/outputs",
    "local_data_dir": "data",
}


# Merge nested dictionaries recursively
# Used when a user override file should replace only specific nested keys
# (for example, overriding just search.max_dph_total without redefining search)
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    result = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = _deep_merge(result[key], value)
        else:
            result[key] = value
    return result
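A worked example of the merge semantics (re-implemented inline so the sketch is self-contained): overriding `search.max_dph_total` leaves the sibling `search.gpu_names` untouched, while non-dict values are replaced wholesale.

```python
from typing import Any

def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    # Identical logic to _deep_merge above: dicts merge recursively,
    # everything else is replaced.
    result = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = deep_merge(result[key], value)
        else:
            result[key] = value
    return result

base = {"search": {"gpu_names": ["RTX 3090"], "max_dph_total": 0.40}}
merged = deep_merge(base, {"search": {"max_dph_total": 0.45}})
print(merged)  # {'search': {'gpu_names': ['RTX 3090'], 'max_dph_total': 0.45}}
```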


# Runtime pipeline config container and loader
@dataclass(slots=True)
class PipelineConfig:
    search: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_SEARCH))
    instance: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_INSTANCE))
    remote: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_REMOTE))
    transfer: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_TRANSFER))
    keep_on_failure: bool = False

    # Load defaults, apply overrides, and return a PipelineConfig instance
    @classmethod
    def load(cls, project_root: Path, override_path: Path | None = None) -> "PipelineConfig":
        load_dotenv(project_root / "pipeline" / ".env")

        # Load default settings from vast.json
        defaults_path = project_root / "pipeline" / "defaults" / "vast.json"
        with open(defaults_path, encoding="utf-8") as handle:
            raw = json.load(handle)

        # Apply user-provided override config (deep merge keeps unspecified defaults)
        if override_path is not None:
            with open(override_path, encoding="utf-8") as handle:
                raw = _deep_merge(raw, json.load(handle))
        return cls(
            search=raw.get("search", {}),
            instance=raw.get("instance", {}),
            remote=raw.get("remote", {}),
            transfer=raw.get("transfer", {}),
            keep_on_failure=raw.get("keep_on_failure", False),
        )

    # Convert the internal "search" config shape into Vast.ai query format
    def build_offer_query(self, *, price_cap: float | None = None) -> dict[str, Any]:
        s = self.search
        query: dict[str, Any] = {
            "limit": s["limit"],
            "order": [[s["order_by"], s["order_direction"]]],
            "type": s["offer_type"],
            "verified": {"eq": s["verified_only"]},
            "rentable": {"eq": s["rentable"]},
            "rented": {"eq": s["rented"]},
            "num_gpus": {"gte": s["num_gpus"]},
            "reliability2": {"gte": s["min_reliability"]},
            "gpu_ram": {"gte": s["min_cuda_ram_mb"]},
            "cpu_ram": {"gte": s["min_cpu_ram_mb"]},
            "disk_space": {"gte": s["min_disk_space_gb"]},
            "direct_port_count": {"gte": s["min_direct_port_count"]},
        }
        # CLI-provided price cap takes precedence over config default
        max_price = price_cap or s.get("max_dph_total")
        if max_price is not None:
            query["dph_total"] = {"lte": max_price}
        gpu_names = s.get("gpu_names", [])
        if gpu_names:
            query["gpu_name"] = {"in": gpu_names}
        excluded = s.get("countries_exclude", [])
        if excluded:
            query["geolocation"] = {"notin": excluded}
        return query

    # Resolve the local directory where training outputs are saved
    def local_output_path(self, project_root: Path) -> Path:
        return (project_root / self.transfer["local_output_dir"]).resolve()

    # Resolve the local directory where datasets are stored
    def local_data_path(self, project_root: Path) -> Path:
        return (project_root / self.transfer["local_data_dir"]).resolve()


# Read the Vast.ai API key from the environment, failing fast if unset
def resolve_api_key() -> str:
    api_key = os.environ.get("VAST_API_KEY")
    if not api_key:
        raise RuntimeError("VAST_API_KEY is not set.")
    return api_key


# Resolve the local SSH private key path, defaulting to ~/.ssh/id_ed25519
def resolve_ssh_private_key() -> Path:
    raw = os.environ.get("VAST_SSH_PRIVATE_KEY", "~/.ssh/id_ed25519")
    path = Path(raw).expanduser()
    if not path.exists():
        raise RuntimeError(
            f"SSH private key not found at {path}. Set VAST_SSH_PRIVATE_KEY to the correct path."
        )
    return path


# Read the corresponding .pub file for a given private key
def read_public_key(private_key_path: Path) -> str:
    public_key_path = Path(f"{private_key_path}.pub")
    if not public_key_path.exists():
        raise RuntimeError(
            f"SSH public key not found at {public_key_path}. Generate it or set VAST_SSH_PRIVATE_KEY."
        )
    return public_key_path.read_text(encoding="utf-8").strip()

@@ -0,0 +1,41 @@
{
  "search": {
    "limit": 100,
    "order_by": "dlperf",
    "order_direction": "desc",
    "sort_mode": "dlp_per_dollar",
    "offer_type": "ondemand",
    "verified_only": true,
    "rentable": true,
    "rented": false,
    "num_gpus": 1,
    "min_reliability": 0.98,
    "min_cuda_ram_mb": 0,
    "min_cpu_ram_mb": 0,
    "min_disk_space_gb": 0,
    "min_direct_port_count": 1,
    "max_dph_total": 0.40,
    "gpu_names": ["RTX 3090", "RTX 3090 Ti"],
    "countries_exclude": []
  },
  "instance": {
    "image": "vastai/pytorch:latest",
    "disk_gb": 48,
    "target_state": "running",
    "label_prefix": "drl-proj",
    "template_hash_id": null
  },
  "remote": {
    "ssh_user": "root",
    "workspace_dir": "/workspace/DRL_PROJ",
    "remote_data_dir": "/workspace/data/DeepFakeFace",
    "remote_output_root": "classifier/outputs",
    "ssh_timeout_seconds": 900,
    "poll_interval_seconds": 10
  },
  "transfer": {
    "local_output_dir": "classifier/outputs",
    "local_data_dir": "data"
  },
  "keep_on_failure": false
}

@@ -0,0 +1,817 @@
import json
import re
import threading
import time
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

from pipeline.config import (
    PipelineConfig,
    read_public_key,
    resolve_api_key,
    resolve_ssh_private_key,
)
from pipeline.remote import (
    RemoteCommandError,
    SshTarget,
    fetch_directory,
    is_process_alive,
    read_remote_file,
    rsync_file,
    rsync_project,
    run_detached,
    shell_join,
    ssh,
    ssh_output,
    tail_log,
    wait_for_ssh,
)
from pipeline.vast_api import VastApiClient, VastApiError, VastInstance


# ── data classes ─────────────────────────────────────────────────────

# Snapshot of a pipeline run written to disk for traceability
@dataclass(slots=True)
class RunManifest:
    created_at: str
    config_paths: list[str]
    instance_id: int | None
    offer_id: int | None
    ssh_host: str | None
    ssh_port: int | None
    status: str
    remote_workspace: str | None


# CLI flags controlling how the pipeline run behaves
@dataclass
class RunOptions:
    download_data: bool = False
    send_cropped: bool = False
    keep_on_failure: bool = False
    dry_run: bool = False
    select_offer: bool = False
    select_template: bool = False
    template_hash: str | None = None
    sort_mode: str | None = None
    region: str | None = None
    price_cap: float | None = None
    list_regions: bool = False
    use_gpu: bool = True


# ── constants ─────────────────────────────────────────────────────────

# ISO-3166-1-alpha-2 country codes used by the --region europe filter
EUROPE_REGION_CODES = {
    "AL", "AD", "AM", "AT", "AZ", "BA", "BE", "BG", "BY", "CH", "CY", "CZ",
    "DE", "DK", "EE", "ES", "FI", "FR", "GB", "GE", "GR", "HR", "HU", "IE",
    "IS", "IT", "LI", "LT", "LU", "LV", "MC", "MD", "ME", "MK", "MT", "NL",
    "NO", "PL", "PT", "RO", "RS", "SE", "SI", "SK", "SM", "TR", "UA", "VA",
}

# How often (in epochs) to rsync generator outputs back during training
_GENERATOR_SYNC_INTERVAL = 50


# Raised when the user aborts interactive offer selection
class OfferSelectionAborted(RuntimeError):
    pass


# Raised when the user aborts interactive template selection
class TemplateSelectionAborted(RuntimeError):
    pass


# ── runner ────────────────────────────────────────────────────────────

# Orchestrates the full lifecycle: search offers -> rent GPU -> train -> fetch results -> destroy
class EphemeralVastRunner:

    # Pre-defined Vast.ai image templates the user can pick interactively
    TEMPLATE_CATALOG: list[dict[str, str]] = [
        {
            "hash_id": "661d064bbda1f2a133816b6d55da07c3",
            "name": "PyTorch (cuDNN Devel)",
            "image": "pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel",
        },
        {
            "hash_id": "b9e5a8f3d4c1e7f6a2b0d8c9e5f1a3b7",
            "name": "PyTorch (Latest)",
            "image": "pytorch/pytorch:latest",
        },
        {
            "hash_id": "none",
            "name": "Custom image (no template)",
            "image": "",
        },
    ]

    # Resolve the project root (two levels up from this file), load the
    # pipeline config, and prepare API and SSH credentials
    def __init__(self, override_path: Path | None) -> None:
        self.project_root = Path(__file__).resolve().parent.parent
        self.config = PipelineConfig.load(self.project_root, override_path)
        self.api = VastApiClient(resolve_api_key())
        self.private_key = resolve_ssh_private_key()
        self.public_key = read_public_key(self.private_key)
        self.exclude_file = self.project_root / "pipeline" / "rsync-excludes.txt"

    # ── formatting helpers ────────────────────────────────────────────

    # Format seconds into a compact human-readable string (e.g. "1d2h30m")
    @staticmethod
    def _fmt_duration(seconds: float) -> str:
        s = max(int(seconds), 0)
        days, s = divmod(s, 86400)
        hours, s = divmod(s, 3600)
        minutes, _ = divmod(s, 60)
        parts: list[str] = []
        if days:
            parts.append(f"{days}d")
        if hours or days:
            parts.append(f"{hours}h")
        parts.append(f"{minutes}m")
        return "".join(parts)

    # Estimate cost from wall-clock time and the offer's $/h rate
    @staticmethod
    def _fmt_cost(elapsed_seconds: float, cost_per_hour: float) -> str:
        return f"${(elapsed_seconds / 3600) * cost_per_hour:.2f}"

    # ── offer search & selection ──────────────────────────────────────

    # Build a sort tuple that ranks offers by the chosen mode (performance / price / dlp_per_dollar)
    @staticmethod
    def _offer_sort_key(offer: dict, sort_mode: str) -> tuple:
        price = float(offer.get("dph_total", float("inf")))
        dlperf = float(offer.get("dlperf", float("-inf")))
        reliability = float(offer.get("reliability2", 0.0))
        if sort_mode == "price":
            return (price, -dlperf, -reliability)
        if sort_mode == "dlp_per_dollar":
            per_dollar = dlperf / price if price > 0 else 0.0
            return (-per_dollar, -reliability, price)
        return (-dlperf, price, -reliability)
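How the modes order the same offers can be checked standalone (same key function, module-level for the sketch, with made-up offer dicts):

```python
def offer_sort_key(offer: dict, sort_mode: str) -> tuple:
    # Same ranking as _offer_sort_key above.
    price = float(offer.get("dph_total", float("inf")))
    dlperf = float(offer.get("dlperf", float("-inf")))
    reliability = float(offer.get("reliability2", 0.0))
    if sort_mode == "price":
        return (price, -dlperf, -reliability)
    if sort_mode == "dlp_per_dollar":
        per_dollar = dlperf / price if price > 0 else 0.0
        return (-per_dollar, -reliability, price)
    return (-dlperf, price, -reliability)

offers = [
    {"id": 1, "dph_total": 0.10, "dlperf": 8.0, "reliability2": 0.99},   # 80 dlperf/$
    {"id": 2, "dph_total": 0.20, "dlperf": 20.0, "reliability2": 0.99},  # 100 dlperf/$
]
by_price = sorted(offers, key=lambda o: offer_sort_key(o, "price"))
by_value = sorted(offers, key=lambda o: offer_sort_key(o, "dlp_per_dollar"))
print([o["id"] for o in by_price])  # [1, 2]
print([o["id"] for o in by_value])  # [2, 1]
```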

    # Match a user-friendly region string (e.g. "europe") against the offer's geolocation code
    @staticmethod
    def _region_matches(offered: str, requested: str) -> bool:
        req = requested.strip().lower()
        if not req:
            return True
        if req == "europe":
            code = offered.rsplit(", ", 1)[-1].strip() if ", " in offered else ""
            return code in EUROPE_REGION_CODES
        return req in offered.lower()
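Standalone illustration of the matching rules (a trimmed country set and made-up geolocation strings; real Vast geolocation strings may differ):

```python
EUROPE = {"PT", "DE", "FR", "ES", "NL"}  # trimmed; the full set is EUROPE_REGION_CODES above

def region_matches(offered: str, requested: str) -> bool:
    # Same rules as _region_matches: empty filter matches everything,
    # "europe" compares the trailing country code, anything else is a
    # case-insensitive substring match on the whole geolocation string.
    req = requested.strip().lower()
    if not req:
        return True
    if req == "europe":
        code = offered.rsplit(", ", 1)[-1].strip() if ", " in offered else ""
        return code in EUROPE
    return req in offered.lower()

print(region_matches("Lisbon, PT", "europe"))    # True (country code PT)
print(region_matches("Lisbon, PT", "portugal"))  # False (substring not present)
print(region_matches("Texas, US", "us"))         # True
```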

    # Query Vast.ai for matching offers, optionally filter by region, then sort
    def _search_offers(
        self,
        sort_mode: str | None = None,
        region: str | None = None,
        price_cap: float | None = None,
    ) -> list[dict]:
        query = self.config.build_offer_query(price_cap=price_cap)
        offers = self.api.search_offers(query)
        if not offers:
            raise VastApiError("No Vast offers matched the configured filters.")
        if region:
            offers = [o for o in offers if self._region_matches(o.get("geolocation") or "", region)]
            if not offers:
                raise VastApiError(f"No offers matched region filter: {region!r}")
        mode = sort_mode or self.config.search.get("sort_mode", "performance")
        return sorted(offers, key=lambda o: self._offer_sort_key(o, mode))

    # Print a summary of available regions and their offer counts
    @staticmethod
    def _print_regions(offers: list[dict]) -> None:
        counts: dict[str, int] = {}
        for o in offers:
            region = o.get("geolocation") or "Unknown"
            counts[region] = counts.get(region, 0) + 1
        for region, count in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
            print(f"{region}: {count}")

    # Render one page of offers with GPU specs, price, and reliability
    def _print_offer_page(self, offers: list[dict], *, page: int, page_size: int) -> None:
        start = page * page_size
        end = min(start + page_size, len(offers))
        print(f"Offers {start + 1}–{end} of {len(offers)}")
        for idx in range(start, end):
            o = offers[idx]
            vram_gb = o.get("gpu_ram", 0) / 1024
            duration = self._fmt_duration(float(o.get("duration") or 0))
            print(
                f" [{idx + 1}] {o.get('gpu_name')} "
                f"gpus={o.get('num_gpus')} "
                f"vram={vram_gb:.0f}GB "
                f"dlperf={o.get('dlperf', 0):.1f} "
                f"$/h={o.get('dph_total', 0):.3f} "
                f"reliability={o.get('reliability2', 0):.3f} "
                f"avail={duration} "
                f"region={o.get('geolocation')} "
                f"id={o.get('id')}"
            )

    # Paginated interactive prompt — user picks an offer or aborts
    def _choose_offer(self, offers: list[dict]) -> dict:
        page, page_size = 0, 10
        last_page = max((len(offers) - 1) // page_size, 0)
        while True:
            self._print_offer_page(offers, page=page, page_size=page_size)
            raw = input("Select offer [number / n / p / q]: ").strip().lower()
            if raw == "n":
                page = min(page + 1, last_page)
            elif raw == "p":
                page = max(page - 1, 0)
            elif raw == "q":
                raise OfferSelectionAborted
            elif raw.isdigit():
                idx = int(raw) - 1
                if 0 <= idx < len(offers):
                    return offers[idx]
                print(" Out of range.")
            else:
                print(" Invalid input.")

    # ── template selection ────────────────────────────────────────────

    # Decide which Vast template hash to use: CLI flag, interactive pick, or config default
    def _resolve_template(self, opts: RunOptions) -> str | None:
        if opts.template_hash:
            return opts.template_hash
        if opts.select_template:
            try:
                return self._choose_template()
            except (TemplateSelectionAborted, KeyboardInterrupt):
                print("Template selection aborted, using default image.")
                return None
        return self.config.instance.get("template_hash_id")

    # Interactive prompt — user picks a Docker image template
    def _choose_template(self) -> str | None:
        print("Available templates:")
        for idx, tpl in enumerate(self.TEMPLATE_CATALOG, 1):
            print(f" [{idx}] {tpl['name']} ({tpl['image'] or 'config image'})")
        while True:
            raw = input("Select template [number / q]: ").strip().lower()
            if raw == "q":
                raise TemplateSelectionAborted
            if raw.isdigit():
                idx = int(raw) - 1
                if 0 <= idx < len(self.TEMPLATE_CATALOG):
                    tpl = self.TEMPLATE_CATALOG[idx]
                    if tpl["hash_id"] == "none":
                        return None
                    if tpl["image"]:
                        self.config.instance["image"] = tpl["image"]
                    print(f" Selected: {tpl['name']}")
                    return tpl["hash_id"]
            print(" Invalid input.")
|
||||
|
||||
    # ── instance lifecycle ────────────────────────────────────────────

    # Build a unique instance label from a prefix, stem, and UTC timestamp
    def _label_for(self, stem: str) -> str:
        ts = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
        return f"{self.config.instance['label_prefix']}-{stem}-{ts}"

    # Assemble the JSON payload sent to Vast.ai when renting an instance
    def _build_instance_payload(self, label: str, *, template_hash_id: str | None = None) -> dict[str, Any]:
        inst = self.config.instance
        payload: dict[str, Any] = {
            "image": inst["image"],
            "disk": inst["disk_gb"],
            "runtype": "ssh_direct ssh_proxy",
            "target_state": inst["target_state"],
            "label": label,
        }
        if template_hash_id:
            payload["template_hash_id"] = template_hash_id
        return payload

    # Poll the API until the instance is running and has SSH credentials
    def _wait_for_instance(self, instance_id: int) -> VastInstance:
        deadline = time.time() + self.config.remote["ssh_timeout_seconds"]
        poll = self.config.remote["poll_interval_seconds"]
        while time.time() < deadline:
            instance = self.api.show_instance(instance_id)
            if instance.actual_status == "running" and instance.ssh_host and instance.ssh_port:
                return instance
            print(f" instance status: {instance.actual_status or 'pending'}...")
            time.sleep(poll)
        raise VastApiError(f"Timed out waiting for instance {instance_id} to be SSH-ready.")

    # Build an SshTarget from the running instance's connection details
    def _build_target(self, instance: VastInstance) -> SshTarget:
        return SshTarget(
            user=self.config.remote["ssh_user"],
            host=instance.ssh_host or instance.public_ipaddr or "",
            port=int(instance.ssh_port or 22),
            private_key=self.private_key,
        )

    # ── manifest ──────────────────────────────────────────────────────

    # Persist the run manifest as a timestamped JSON file under the output directory
    def _write_manifest(self, manifest: RunManifest, local_output_root: Path) -> Path:
        d = local_output_root / "pipeline"
        d.mkdir(parents=True, exist_ok=True)
        ts = manifest.created_at.replace(":", "").replace("-", "")
        path = d / f"{ts}.json"
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(asdict(manifest), fh, indent=2)
        return path

    # ── remote workspace setup ────────────────────────────────────────

    # Upload project files, bootstrap the venv, and ensure training data is present
    def _prepare_remote_workspace(
        self, target: SshTarget, *, download_data: bool, send_cropped: bool,
        first_config: Path | None = None,
    ) -> None:
        ws = self.config.remote["workspace_dir"]
        poll = self.config.remote["poll_interval_seconds"]

        print(" Syncing project files...")
        ssh(target, shell_join(["mkdir", "-p", ws]))
        rsync_project(self.project_root, target, ws, self.exclude_file)

        print(" Bootstrapping remote environment...")
        ssh(target, shell_join(["bash", "-lc", f"cd {ws} && bash pipeline/scripts/bootstrap_env.sh"]))

        if send_cropped and first_config is not None:
            # Determine which cropped subdirectory to send based on the first config
            first_mod, _ = self._module_for_config(first_config)
            crop_subdir = "generator" if first_mod.startswith("generator/") else "classifier"
            local_zip = self.project_root / f"cropped_{crop_subdir}.zip"
            if not local_zip.exists():
                print(f" Warning: --send-cropped set but cropped_{crop_subdir}.zip not found.")
                print(f" Create it with: zip -r cropped_{crop_subdir}.zip cropped/{crop_subdir}/")
            else:
                size_mb = local_zip.stat().st_size / (1024 * 1024)
                print(f" Sending cropped_{crop_subdir}.zip ({size_mb:.0f} MB)...")
                rsync_file(local_zip, target, ws)
                print(" Unzipping on remote...")
                ssh(target, shell_join(["bash", "-lc", f"cd {ws} && unzip -q -o cropped_{crop_subdir}.zip"]))
                print(f" Pre-cropped {crop_subdir} images ready.")

        data_ready = self._remote_data_exists(target, f"{ws}/data")
        if download_data or not data_ready:
            # Download the dataset in a detached process so the SSH timeout doesn't kill it
            print(" Downloading dataset from HuggingFace...")
            log = f"{ws}/.pipeline_logs/fetch_ds.log"
            ssh(target, f"mkdir -p {ws}/.pipeline_logs")
            pid = run_detached(
                target,
                f"cd {ws} && source .venv/bin/activate && python3 classifier/tools/fetch_ds.py",
                log,
            )
            reconnects = 0
            # Poll until the fetch process exits, with reconnect logic for SSH drops
            while True:
                try:
                    if not is_process_alive(target, pid):
                        break
                    reconnects = 0
                except RemoteCommandError as exc:
                    if "Command failed (255)" not in str(exc) or reconnects >= 10:
                        raise
                    reconnects += 1
                    time.sleep(min(5 * reconnects, 30))
                    continue
                time.sleep(poll)
            print(" Dataset ready.")

    # ── config routing ────────────────────────────────────────────────

    # Determine which training script and output dir to use based on config location
    def _module_for_config(self, config_path: Path) -> tuple[str, str]:
        try:
            rel = config_path.resolve().relative_to(self.project_root)
        except ValueError:
            rel = config_path
        if rel.parts[0] == "generator":
            return "generator/run.py", "generator/outputs"
        return "classifier/run.py", "classifier/outputs"

    # Merge two dicts recursively (override wins on leaf keys)
    @staticmethod
    def _deep_merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
        result = base.copy()
        for key, value in override.items():
            if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                result[key] = EphemeralVastRunner._deep_merge_dicts(result[key], value)
            else:
                result[key] = value
        return result

    # Recursively resolve "extends" references in JSON configs with cycle detection
    def _load_config_with_extends(self, config_path: Path, seen: set[Path] | None = None) -> dict[str, Any]:
        if seen is None:
            seen = set()
        resolved = config_path.resolve()
        if resolved in seen:
            raise ValueError(f"Circular config inheritance detected at: {config_path}")
        seen.add(resolved)
        with open(config_path, encoding="utf-8") as fh:
            cfg = json.load(fh)
        base_ref = cfg.pop("extends", None)
        if not base_ref:
            seen.remove(resolved)
            return cfg
        base_path = (config_path.parent / base_ref).resolve()
        base_cfg = self._load_config_with_extends(base_path, seen=seen)
        seen.remove(resolved)
        return self._deep_merge_dicts(base_cfg, cfg)

    # Return a stable signature used to detect duplicate training configs
    def _normalized_config_signature(self, config_path: Path) -> str:
        run_script, _ = self._module_for_config(config_path)
        # Generator configs do not currently use shared/extends inheritance.
        if run_script.startswith("generator/"):
            with open(config_path, encoding="utf-8") as fh:
                cfg = json.load(fh)
        else:
            cfg = self._load_config_with_extends(config_path)
            shared_path = config_path.parent.parent / "shared.json"
            if shared_path.exists():
                with open(shared_path, encoding="utf-8") as fh:
                    shared_cfg = json.load(fh)
                cfg = self._deep_merge_dicts(shared_cfg, cfg)

        # run_name should not influence whether two configs are equivalent to train.
        cfg.pop("run_name", None)
        return json.dumps(cfg, sort_keys=True, separators=(",", ":"))

    # Detect configs that are pure extends (only run_name + extends, no new training settings).
    # These are pointers to an already-trained experiment and should be skipped.
    def _is_pure_extend(self, config_path: Path) -> str | None:
        with open(config_path, encoding="utf-8") as fh:
            raw = json.load(fh)
        if "extends" not in raw:
            return None
        non_meta = {k for k in raw if k not in ("run_name", "extends")}
        if non_meta:
            return None
        base_path = (config_path.parent / raw["extends"]).resolve()
        return str(base_path.relative_to(self.project_root))

    def _dedupe_training_configs(self, config_paths: list[Path]) -> list[Path]:
        seen: dict[str, Path] = {}
        deduped: list[Path] = []
        for cp in config_paths:
            pure_extends = self._is_pure_extend(cp)
            if pure_extends is not None:
                print(f"Skipping {cp.name} (pure extend of {pure_extends})")
                continue
            sig = self._normalized_config_signature(cp)
            if sig in seen:
                first = seen[sig]
                print(f"Skipping duplicate config {cp.name} (same training settings as {first.name})")
                continue
            seen[sig] = cp
            deduped.append(cp)
        return deduped

    # ── remote directory checks ───────────────────────────────────────

    # Check whether a directory exists on the remote host
    def _remote_dir_exists(self, target: SshTarget, path: str) -> bool:
        out = ssh_output(target, f"if [ -d {path} ]; then echo yes; else echo no; fi").strip()
        return out == "yes"

    # Check that the dataset is populated (not just an empty directory)
    def _remote_data_exists(self, target: SshTarget, data_dir: str) -> bool:
        out = ssh_output(target, f"if [ -d {data_dir}/wiki ]; then echo yes; else echo no; fi").strip()
        return out == "yes"

    # ── training ─────────────────────────────────────────────────────

    # Launch the training script inside nohup, stream logs, and handle reconnects
    def _run_training(self, target: SshTarget, config_path: Path, opts: RunOptions) -> None:
        ws = self.config.remote["workspace_dir"]
        config_rel = config_path.resolve().relative_to(self.project_root)
        run_script, output_root = self._module_for_config(config_path)
        is_generator = run_script.startswith("generator/")

        run_cmd = shell_join(["python3", "-u", run_script, str(config_rel), "--output-root", output_root])
        if opts.use_gpu:
            run_cmd += " --use-gpu"

        log_dir = f"{ws}/.pipeline_logs"
        log_path = f"{log_dir}/{config_path.stem}.log"
        exit_path = f"{log_dir}/{config_path.stem}.exit"

        ssh(target, f"mkdir -p {log_dir}")

        # Run training under nohup; write the exit code to a marker file when done
        pid = run_detached(
            target,
            f"cd {ws} && source .venv/bin/activate && {run_cmd}; echo $? > {exit_path}",
            log_path,
        )
        print(f" Training started (PID {pid})")

        poll = self.config.remote["poll_interval_seconds"]
        max_reconnects = 10
        reconnects = 0
        next_sync = _GENERATOR_SYNC_INTERVAL if is_generator else None
        local_output_root = self.project_root / output_root
        stop_streaming = threading.Event()

        # Background thread: tail the remote log and print tqdm-aware output
        def _stream_worker() -> None:
            tail_proc = None
            last_was_progress = False
            try:
                time.sleep(1)  # wait for the log file to be created
                tail_proc = tail_log(target, log_path)
                for raw in tail_proc.stdout:
                    if stop_streaming.is_set():
                        break
                    # tqdm uses \r for in-place updates; take only the last segment
                    line = raw.rstrip("\r\n").rsplit("\r", 1)[-1].rstrip()
                    if not line:
                        continue
                    is_progress = "it/s]" in line or "%|" in line
                    if is_progress:
                        print(f"\r {line}", end="", flush=True)
                        last_was_progress = True
                    else:
                        if last_was_progress:
                            print(flush=True)
                            last_was_progress = False
                        print(f" {line}", flush=True)
            except Exception:
                pass
            finally:
                if last_was_progress:
                    print(flush=True)
                if tail_proc is not None:
                    tail_proc.terminate()
                    tail_proc.wait(timeout=5)

        stream_thread = threading.Thread(target=_stream_worker, daemon=True)
        stream_thread.start()

        try:
            # Main polling loop: check if the process is alive, periodically sync generator outputs
            while True:
                try:
                    alive = is_process_alive(target, pid)
                except RemoteCommandError as exc:
                    if "Command failed (255)" not in str(exc) or reconnects >= max_reconnects:
                        raise
                    reconnects += 1
                    print(f"\n SSH dropped (reconnect {reconnects}/{max_reconnects}), "
                          "training continues on remote...")
                    time.sleep(min(5 * reconnects, 30))
                    continue

                if is_generator and next_sync is not None:
                    epoch = self._latest_epoch_in_log(target, log_path)
                    if epoch >= next_sync:
                        print(f" Syncing generator outputs at epoch {next_sync}...")
                        self._fetch_outputs(target, output_root, local_output_root)
                        next_sync += _GENERATOR_SYNC_INTERVAL

                if not alive:
                    break

                reconnects = 0
                time.sleep(poll)
        finally:
            stop_streaming.set()
            stream_thread.join(timeout=10)

        # Read the exit code marker that the nohup wrapper wrote
        raw_exit = read_remote_file(target, exit_path)
        exit_code = int(raw_exit) if raw_exit is not None else -1

        if exit_code != 0:
            if self._remote_dir_exists(target, f"{ws}/{output_root}"):
                self._fetch_outputs(target, output_root, local_output_root)
            log_tail = read_remote_file(target, log_path)
            snippet = "\n".join((log_tail or "").splitlines()[-30:])
            raise RemoteCommandError(f"Training failed (exit {exit_code}).\n{snippet}")

    # Parse the latest epoch number from the training log (used for generator sync scheduling)
    @staticmethod
    def _latest_epoch_in_log(target: SshTarget, log_path: str) -> int:
        try:
            out = ssh_output(target, f"tail -n 200 {log_path} 2>/dev/null || true")
            epochs = [int(m) for m in re.findall(r"\[(\d+)/\d+\]", out)]
            return max(epochs, default=0)
        except RemoteCommandError:
            return 0

    # Download remote outputs with up to 3 retry attempts (exponential back-off)
    def _fetch_outputs(self, target: SshTarget, remote_output_root: str, local_output_root: Path) -> None:
        remote_path = f"{self.config.remote['workspace_dir']}/{remote_output_root}"
        delay = 5
        for attempt in range(1, 4):
            try:
                fetch_directory(target, remote_path, local_output_root)
                return
            except RemoteCommandError as exc:
                if attempt < 3:
                    print(f" Download attempt {attempt} failed: {exc}")
                    print(f" Retrying in {delay}s...")
                    time.sleep(delay)
                    delay *= 2
                else:
                    raise

    # ── public commands ───────────────────────────────────────────────

    # Full pipeline: validate configs -> find offer -> rent instance -> train -> fetch -> destroy
    def run(self, config_paths: list[Path], opts: RunOptions) -> None:
        resolved = []
        for cp in config_paths:
            cp = (self.project_root / cp).resolve() if not cp.is_absolute() else cp
            if not cp.exists():
                raise FileNotFoundError(f"Config not found: {cp}")
            resolved.append(cp)

        resolved = self._dedupe_training_configs(resolved)
        # Abort early if all configs were duplicates
        if not resolved:
            raise ValueError("No unique configs to run after deduplication.")

        n = len(resolved)
        _, first_output_root = self._module_for_config(resolved[0])
        local_output_root = self.project_root / first_output_root

        self.api.ensure_ssh_key(self.public_key)
        offers = self._search_offers(opts.sort_mode, opts.region, opts.price_cap)

        if opts.list_regions:
            self._print_regions(offers)
            return

        try:
            offer = self._choose_offer(offers) if opts.select_offer else offers[0]
        except (OfferSelectionAborted, KeyboardInterrupt):
            print("Aborted.")
            return

        if opts.dry_run:
            print(json.dumps(offer if opts.select_offer else offers[:10], indent=2))
            return

        label_stem = resolved[0].stem if n == 1 else resolved[0].parent.name
        template_hash_id = self._resolve_template(opts)
        offer_id = int(offer["id"])
        cost_per_hour = float(offer.get("dph_total", 0))

        manifest = RunManifest(
            created_at=datetime.now(UTC).isoformat(),
            config_paths=[str(cp.relative_to(self.project_root)) for cp in resolved],
            instance_id=None,
            offer_id=offer_id,
            ssh_host=None,
            ssh_port=None,
            status="creating",
            remote_workspace=self.config.remote["workspace_dir"],
        )
        manifest_path = self._write_manifest(manifest, local_output_root)

        instance_id: int | None = None
        should_destroy = True
        start_time = time.time()

        try:
            # Build the instance payload and rent the GPU
            payload = self._build_instance_payload(self._label_for(label_stem),
                                                   template_hash_id=template_hash_id)
            print(f"Creating instance from offer {offer_id} (${cost_per_hour:.3f}/h)...")
            instance_id = self.api.create_instance(offer_id, payload)
            self.api.attach_ssh_key(instance_id, self.public_key)

            # Record the instance ID in the manifest immediately for debugging
            manifest.instance_id = instance_id

            # Wait for the instance to be SSH-ready, then update the manifest
            print(f"Waiting for instance {instance_id}...")
            instance = self._wait_for_instance(instance_id)
            manifest.ssh_host = instance.ssh_host
            manifest.ssh_port = instance.ssh_port
            manifest.status = instance.actual_status
            self._write_manifest(manifest, local_output_root)
            print(f"Instance ready: {instance.ssh_host}:{instance.ssh_port} "
                  f"({instance.gpu_name}, ${instance.dph_total}/h)")

            target = self._build_target(instance)

            # Wait for SSH to become available before running any remote commands
            print("Waiting for SSH...")
            wait_for_ssh(
                target,
                timeout_seconds=self.config.remote["ssh_timeout_seconds"],
                poll_interval_seconds=self.config.remote["poll_interval_seconds"],
            )
            self._prepare_remote_workspace(
                target, download_data=opts.download_data, send_cropped=opts.send_cropped,
                first_config=resolved[0],
            )

            # Run each training config sequentially on the same instance
            for i, config_path in enumerate(resolved, 1):
                print(f"\n[{i}/{n}] Training: {config_path.name}")
                self._run_training(target, config_path, opts)
                print(f"[{i}/{n}] Fetching outputs...")
                _, output_root = self._module_for_config(config_path)
                self._fetch_outputs(target, output_root, self.project_root / output_root)

            elapsed = time.time() - start_time
            manifest.status = "completed"
            self._write_manifest(manifest, local_output_root)
            print(f"\nAll {n} run(s) completed in {self._fmt_duration(elapsed)} "
                  f"(~{self._fmt_cost(elapsed, cost_per_hour)}). Manifest: {manifest_path}")

        except KeyboardInterrupt:
            elapsed = time.time() - start_time
            manifest.status = "cancelled"
            self._write_manifest(manifest, local_output_root)
            print(f"\nCancelled after {self._fmt_duration(elapsed)} "
                  f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")

        except Exception:
            elapsed = time.time() - start_time
            manifest.status = "failed"
            self._write_manifest(manifest, local_output_root)
            print(f" Failed after {self._fmt_duration(elapsed)} "
                  f"(~{self._fmt_cost(elapsed, cost_per_hour)}).")
            if opts.keep_on_failure or self.config.keep_on_failure:
                should_destroy = False
            raise

        finally:
            if instance_id is not None and should_destroy:
                elapsed = time.time() - start_time
                print(f"Destroying instance {instance_id} "
                      f"(total: {self._fmt_duration(elapsed)}, "
                      f"~{self._fmt_cost(elapsed, cost_per_hour)})")
                self.api.destroy_instance(instance_id)

    # CLI subcommand: search and display available GPU offers without renting
    def offers(
        self,
        *,
        sort_mode: str | None,
        region: str | None,
        price_cap: float | None,
        select_offer: bool,
        list_regions: bool,
        limit_output: int,
    ) -> None:
        offers = self._search_offers(sort_mode, region, price_cap)
        if list_regions:
            self._print_regions(offers)
            return
        if select_offer:
            try:
                offer = self._choose_offer(offers)
            except (OfferSelectionAborted, KeyboardInterrupt):
                print("Aborted.")
                return
            print(json.dumps(offer, indent=2))
        else:
            print(json.dumps(offers[:limit_output], indent=2))

    # CLI subcommand: rent an instance and print its connection details (no training)
    def up(self, *, label: str | None, opts: RunOptions | None = None) -> None:
        if opts is None:
            opts = RunOptions()
        self.api.ensure_ssh_key(self.public_key)
        template_hash_id = self._resolve_template(opts)
        offer = self._search_offers(opts.sort_mode, opts.region, opts.price_cap)[0]
        offer_id = int(offer["id"])
        payload = self._build_instance_payload(
            label or self._label_for("manual"), template_hash_id=template_hash_id
        )
        instance_id = self.api.create_instance(offer_id, payload)
        self.api.attach_ssh_key(instance_id, self.public_key)
        instance = self._wait_for_instance(instance_id)
        print(json.dumps({
            "offer_id": offer_id,
            "instance_id": instance_id,
            "ssh_host": instance.ssh_host,
            "ssh_port": instance.ssh_port,
            "gpu_name": instance.gpu_name,
            "dph_total": instance.dph_total,
        }, indent=2))

    # CLI subcommand: print the raw instance status JSON
    def status(self, instance_id: int) -> None:
        print(json.dumps(self.api.show_instance(instance_id).raw, indent=2))

    # CLI subcommand: destroy a running instance by ID
    def down(self, instance_id: int) -> None:
        self.api.destroy_instance(instance_id)
        print(f"Destroyed instance {instance_id}")
@@ -0,0 +1,180 @@
import shlex
import subprocess
import time
from dataclasses import dataclass
from pathlib import Path


# Error raised when a remote SSH or rsync command fails
class RemoteCommandError(RuntimeError):
    pass


# Connection details for an SSH target (Vast instance)
@dataclass(slots=True)
class SshTarget:
    user: str
    host: str
    port: int
    private_key: Path

    # user@host string used by SSH and rsync
    @property
    def destination(self) -> str:
        return f"{self.user}@{self.host}"

    # Base SSH flags: identity, port, keep-alive, and auto-accept host keys
    def ssh_base_command(self) -> list[str]:
        return [
            "ssh",
            "-i", str(self.private_key),
            "-p", str(self.port),
            "-o", "BatchMode=yes",
            "-o", "StrictHostKeyChecking=accept-new",
            "-o", "ServerAliveInterval=30",
            "-o", "ServerAliveCountMax=10",
            "-o", "TCPKeepAlive=yes",
            self.destination,
        ]

    # SSH flags formatted as a string for rsync's -e option
    def rsync_ssh_opts(self) -> str:
        return " ".join(self.ssh_base_command()[:-1])

# ── low-level subprocess helpers ─────────────────────────────────────

# Run a local subprocess, raising RemoteCommandError on non-zero exit
def _run(command: list[str]) -> None:
    result = subprocess.run(command, check=False, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        snippet = result.stderr.strip()[-500:] if result.stderr else ""
        msg = f"Command failed ({result.returncode}): {' '.join(command)}"
        raise RemoteCommandError(f"{msg}\n{snippet}" if snippet else msg)

# ── SSH ───────────────────────────────────────────────────────────────

# Repeatedly attempt an SSH connection until the target accepts it or the deadline passes
def wait_for_ssh(target: SshTarget, *, timeout_seconds: int, poll_interval_seconds: int) -> None:
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        result = subprocess.run(
            target.ssh_base_command() + ["true"],
            check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
        )
        if result.returncode == 0:
            return
        remaining = int(deadline - time.time())
        print(f" SSH not ready yet, retrying... ({remaining}s remaining)")
        time.sleep(poll_interval_seconds)
    raise RemoteCommandError(
        f"Timed out waiting for SSH on {target.host}:{target.port} after {timeout_seconds}s."
    )


# Execute a single command on the remote host via SSH
def ssh(target: SshTarget, command: str) -> None:
    _run(target.ssh_base_command() + [command])


# Execute a remote command and return its stdout
def ssh_output(target: SshTarget, command: str) -> str:
    result = subprocess.run(
        target.ssh_base_command() + [command],
        check=False, capture_output=True, text=True,
    )
    if result.returncode != 0:
        raise RemoteCommandError(
            f"Command failed ({result.returncode}): {command}\n{result.stderr.strip()}"
        )
    return result.stdout

# ── detached process management ───────────────────────────────────────

# Launch a remote command under nohup so it survives SSH drops; return PID
def run_detached(target: SshTarget, command: str, log_path: str) -> int:
    inner = f"nohup bash -c {shlex.quote(command)} > {shlex.quote(log_path)} 2>&1 & echo $!"
    output = ssh_output(target, inner)
    return int(output.strip().splitlines()[-1])


# Return True if the remote process is running.
# Raise RemoteCommandError on SSH failure (exit 255) so callers can apply
# reconnect logic and only return False when kill -0 confirms the PID is gone.
def is_process_alive(target: SshTarget, pid: int) -> bool:
    try:
        ssh(target, f"kill -0 {pid}")
        return True
    except RemoteCommandError as exc:
        if "Command failed (255)" in str(exc):
            raise  # SSH failure, let caller handle
        return False  # process not found

# Read a remote file's contents; returns None if the file does not exist
def read_remote_file(target: SshTarget, path: str) -> str | None:
    return ssh_output(target, f"cat {shlex.quote(path)} 2>/dev/null || true").strip() or None


# Open a persistent tail -f on a remote file for real-time streaming
def tail_log(target: SshTarget, log_path: str) -> subprocess.Popen:
    cmd = target.ssh_base_command() + [f"tail -f -n +1 {shlex.quote(log_path)}"]
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True)

# ── rsync ─────────────────────────────────────────────────────────────

# Upload the entire project directory, respecting the exclude list
def rsync_project(local_root: Path, target: SshTarget, remote_root: str, exclude_file: Path) -> None:
    _run([
        "rsync", "-az", "--info=progress2", "--delete",
        "--exclude-from", str(exclude_file),
        "-e", target.rsync_ssh_opts(),
        f"{local_root}/",
        f"{target.destination}:{remote_root}/",
    ])


# Upload a single file to a remote directory
def rsync_file(local_file: Path, target: SshTarget, remote_dir: str) -> None:
    _run([
        "rsync", "-az", "--partial", "--info=progress2",
        "-e", target.rsync_ssh_opts(),
        str(local_file),
        f"{target.destination}:{remote_dir}/",
    ])

# Download a remote directory, verify file count, and resume partial transfers
def fetch_directory(target: SshTarget, remote_dir: str, local_dir: Path) -> None:
    local_dir.mkdir(parents=True, exist_ok=True)

    # Count remote files first so we can verify completeness after transfer
    remote_count = int(ssh_output(
        target, f"find {remote_dir.rstrip('/')} -type f 2>/dev/null | wc -l"
    ).strip())

    _run([
        "rsync", "-az", "--partial", "--append-verify", "--info=progress2",
        "-e", target.rsync_ssh_opts(),
        f"{target.destination}:{remote_dir.rstrip('/')}/",
        f"{local_dir}/",
    ])

    local_count = sum(1 for p in local_dir.rglob("*") if p.is_file())
    # Raise if fewer files arrived than expected (corrupted / interrupted transfer)
    if local_count < remote_count:
        raise RemoteCommandError(
            f"Download incomplete: got {local_count} files, expected {remote_count}. "
            f"Remote: {remote_dir} Local: {local_dir}"
        )
    print(f" Verified: {local_count} files downloaded successfully")

# ── misc ──────────────────────────────────────────────────────────────

# Quote and join shell arguments for safe remote execution
def shell_join(parts: list[str]) -> str:
    return " ".join(shlex.quote(p) for p in parts)
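

# Hedged usage sketch (illustrative, not part of the original module): shell_join
# quotes each argument with shlex.quote, so remote paths containing spaces or
# shell metacharacters survive the remote shell intact, e.g.
#   shell_join(["mkdir", "-p", "/tmp/my dir"])  ->  "mkdir -p '/tmp/my dir'"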
@@ -0,0 +1,15 @@
.claude/
.git/

.venv/
__pycache__/
.ipynb_checkpoints/

/data/
/cropped/
/cropped_classifier.zip
/cropped_generator.zip

/classifier/outputs/
/generator/outputs/
@@ -0,0 +1,48 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$ROOT_DIR"

PYTHON_BIN="${PYTHON_BIN:-python3}"
USE_SYSTEM_SITE_PACKAGES="${USE_SYSTEM_SITE_PACKAGES:-1}"
SKIP_TORCH_INSTALL="${SKIP_TORCH_INSTALL:-1}"

VENV_ARGS=()
if [ "$USE_SYSTEM_SITE_PACKAGES" = "1" ]; then
  VENV_ARGS+=(--system-site-packages)
fi

"$PYTHON_BIN" -m venv "${VENV_ARGS[@]}" .venv
source .venv/bin/activate

python -m pip install --upgrade pip setuptools wheel

# Capture system torch versions before venv installs override them
TORCH_VERSION="$($PYTHON_BIN -c 'import torch; print(torch.__version__)' 2>/dev/null || true)"
TV_VERSION="$($PYTHON_BIN -c 'import torchvision; print(torchvision.__version__)' 2>/dev/null || true)"

if [ "$SKIP_TORCH_INSTALL" = "1" ]; then
  filtered_requirements="$(mktemp)"
  grep -Ev '^(torch|torchvision)([<>=!~].*)?$' requirements.txt > "$filtered_requirements"

  # Pin system torch/torchvision so transitive deps (e.g. facenet-pytorch) can't downgrade them
  constraints_file="$(mktemp)"
  [ -n "$TORCH_VERSION" ] && echo "torch==$TORCH_VERSION" >> "$constraints_file"
  [ -n "$TV_VERSION" ] && echo "torchvision==$TV_VERSION" >> "$constraints_file"
  python -m pip install -c "$constraints_file" -r "$filtered_requirements"
  rm -f "$filtered_requirements" "$constraints_file"
else
  python -m pip install -r requirements.txt
fi

mkdir -p classifier/outputs/logs classifier/outputs/models classifier/outputs/analysis classifier/outputs/figures classifier/outputs/pipeline

python - <<'PY'
try:
    import torch
    print(f"torch={torch.__version__} cuda_available={torch.cuda.is_available()}")
except Exception as exc:
    print(f"torch check failed: {exc}")
PY
|
||||
|
||||
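A minimal demonstration of the requirements filter the script uses; the sample file contents and the `/tmp/req_demo.txt` path are invented:

```shell
# Invented requirements file showing what the filter keeps and drops
printf 'torch==2.1.0\ntorchvision>=0.16\nnumpy\nfacenet-pytorch\n' > /tmp/req_demo.txt

# Drops torch/torchvision pins, but keeps packages that merely contain "torch"
grep -Ev '^(torch|torchvision)([<>=!~].*)?$' /tmp/req_demo.txt
# → numpy
# → facenet-pytorch
```

The anchored regex only matches a line that *starts* as `torch` or `torchvision` followed by a version specifier (or end of line), so names like `facenet-pytorch` pass through untouched.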
@@ -0,0 +1,131 @@
import json
from dataclasses import dataclass
from typing import Any
from urllib import error, request


# Generic error raised for any Vast.ai API failure
class VastApiError(RuntimeError):
    pass


# Lightweight view of a Vast.ai instance with the fields the pipeline cares about
@dataclass(slots=True)
class VastInstance:
    id: int
    actual_status: str
    ssh_host: str | None
    ssh_port: int | None
    public_ipaddr: str | None
    gpu_name: str | None
    dph_total: float | None
    raw: dict[str, Any]


# Thin wrapper around the Vast.ai REST API
class VastApiClient:
    def __init__(self, api_key: str, *, base_url: str = "https://console.vast.ai") -> None:
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")

    # Low-level request helper — sends JSON, returns parsed response body
    def _request(self, method: str, path: str, payload: dict[str, Any] | None = None) -> Any:
        url = f"{self.base_url}{path}"
        data = None
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if payload is not None:
            headers["Content-Type"] = "application/json"
            data = json.dumps(payload).encode("utf-8")

        req = request.Request(url, method=method, data=data, headers=headers)
        try:
            with request.urlopen(req, timeout=60) as response:
                body = response.read()
        # Vast.ai returns varied error formats; surface whatever body we get
        except error.HTTPError as exc:
            details = exc.read().decode("utf-8", errors="replace")
            raise VastApiError(f"{method} {path} failed with {exc.code}: {details}") from exc
        except error.URLError as exc:
            raise VastApiError(f"{method} {path} failed: {exc.reason}") from exc

        if not body:
            return None
        return json.loads(body)

    # Fetch the currently authenticated user's profile
    def show_user(self) -> dict[str, Any]:
        return self._request("GET", "/api/v0/users/current/")

    # ── SSH keys ────────────────────────────────────────────────────────

    # List registered SSH keys; handles inconsistent response shapes from the API
    def show_ssh_keys(self) -> list[dict[str, Any]]:
        response = self._request("GET", "/api/v0/ssh/")
        if isinstance(response, list):
            return response
        if isinstance(response, dict):
            for key in ("keys", "ssh_keys"):
                value = response.get(key)
                if isinstance(value, list):
                    return value
        raise VastApiError(f"Unexpected SSH key response: {response}")

    # Register the public key if it isn't already present
    def ensure_ssh_key(self, public_key: str) -> None:
        existing_keys = self.show_ssh_keys()
        if any(
            (
                item.get("key")
                or item.get("public_key")
                or item.get("ssh_key")
                or ""
            ).strip() == public_key
            for item in existing_keys
        ):
            return
        self._request("POST", "/api/v0/ssh/", {"ssh_key": public_key})

    # Authorise an SSH key for a running instance
    def attach_ssh_key(self, instance_id: int, public_key: str) -> None:
        self._request("POST", f"/api/v0/instances/{instance_id}/ssh/", {"ssh_key": public_key})

    # ── Offers ─────────────────────────────────────────────────────────

    # Search available GPU offers matching a query filter
    def search_offers(self, query: dict[str, Any]) -> list[dict[str, Any]]:
        response = self._request("POST", "/api/v0/bundles/", query)
        offers = response.get("offers", [])
        if isinstance(offers, dict):
            return [offers]
        return offers

    # ── Instances ──────────────────────────────────────────────────────

    # Rent an offer, returning the new contract (instance) ID
    def create_instance(self, offer_id: int, payload: dict[str, Any]) -> int:
        response = self._request("PUT", f"/api/v0/asks/{offer_id}/", payload)
        if not response or not response.get("success"):
            raise VastApiError(f"Instance creation failed for offer {offer_id}: {response}")
        return int(response["new_contract"])

    # Fetch current status and connection details for an instance
    def show_instance(self, instance_id: int) -> VastInstance:
        response = self._request("GET", f"/api/v0/instances/{instance_id}/")
        raw = response.get("instances")
        if not raw:
            raise VastApiError(f"No instance details found for {instance_id}: {response}")
        return VastInstance(
            id=int(raw["id"]),
            actual_status=raw.get("actual_status", ""),
            ssh_host=raw.get("ssh_host"),
            ssh_port=raw.get("ssh_port"),
            public_ipaddr=raw.get("public_ipaddr"),
            gpu_name=raw.get("gpu_name"),
            dph_total=raw.get("dph_total"),
            raw=raw,
        )

    # Permanently destroy an instance (releases the GPU and billing)
    def destroy_instance(self, instance_id: int) -> None:
        self._request("DELETE", f"/api/v0/instances/{instance_id}/")