183 lines
6.6 KiB
Python
183 lines
6.6 KiB
Python
import json
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
# Load environment variables from pipeline/.env if present
|
|
# This file is intentionally optional and gitignored
|
|
# Expected keys:
|
|
# VAST_API_KEY=<your-api-key>
|
|
# VAST_SSH_PRIVATE_KEY=/home/your-user/.ssh/id_ed25519
|
|
def load_dotenv(dotenv_path: Path) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` file.

    A missing file is silently ignored. Blank lines, ``#`` comment lines,
    lines without ``=`` and lines whose key is empty are skipped. A leading
    ``export`` keyword in front of the key is dropped. One pair of matching
    surrounding quotes (single or double) is removed from the value.
    Variables that already exist in the environment are never overwritten.

    Args:
        dotenv_path: Location of the dotenv file.
    """
    if not dotenv_path.exists():
        return

    for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue

        key, separator, value = line.partition("=")
        if not separator:
            continue

        # Accept shell-style "export KEY=value" lines.
        words = key.split()
        if len(words) == 2 and words[0] == "export":
            key = words[1]
        else:
            key = key.strip()
        if not key:
            # "=value" has no usable name; the OS rejects empty variable names.
            continue

        value = value.strip()
        # Remove only a balanced pair of quotes so that values which merely
        # start or end with a quote character are kept intact.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]

        # setdefault: the real environment takes precedence over the file.
        os.environ.setdefault(key, value)
|
|
|
|
# Default offer search settings
|
|
DEFAULT_SEARCH: dict[str, Any] = {
    # Result size and ordering; PipelineConfig.build_offer_query sends these
    # as "limit" and "order".
    "limit": 100,
    "order_by": "dlperf",
    "order_direction": "desc",
    # Not read by build_offer_query; presumably consumed by the caller that
    # ranks returned offers -- confirm.
    "sort_mode": "performance",
    # Equality filters (sent as "type" and {"eq": ...} clauses).
    "offer_type": "ondemand",
    "verified_only": True,
    "rentable": True,
    "rented": False,
    # Lower bounds (sent as {"gte": ...} clauses).
    "num_gpus": 1,
    "min_reliability": 0.98,
    "min_cuda_ram_mb": 0,
    "min_cpu_ram_mb": 0,
    "min_disk_space_gb": 0,
    "min_direct_port_count": 1,
    # Upper bound sent as "dph_total" {"lte": ...}.
    "max_dph_total": 0.20,
    # List filters; build_offer_query omits the clause when the list is empty.
    "gpu_names": ["RTX 3090", "RTX 3090 Ti"],
    "countries_exclude": [],
}
|
|
|
|
# Default instance settings
|
|
DEFAULT_INSTANCE: dict[str, Any] = {
    # NOTE(review): none of these keys are read inside this module; the notes
    # below are inferred from the key names -- confirm against the code that
    # creates the instance.
    "image": "vastai/pytorch:latest",  # presumably the container image to launch
    "disk_gb": 48,  # presumably the requested disk size in GB
    "target_state": "running",
    "label_prefix": "drl-proj",
    "template_hash_id": None,  # None by default, i.e. no template selected
}
|
|
|
|
# Default remote SSH/workspace settings
|
|
DEFAULT_REMOTE: dict[str, Any] = {
    # NOTE(review): these keys are not read inside this module; confirm their
    # exact use against the SSH / sync code that consumes them.
    "ssh_user": "root",
    # Paths on the remote machine. "remote_output_root" is relative --
    # presumably resolved against "workspace_dir"; verify in the caller.
    "workspace_dir": "/workspace/DRL_PROJ",
    "remote_data_dir": "/workspace/data/DeepFakeFace",
    "remote_output_root": "classifier/outputs",
    # Timing values, in seconds as the key names state.
    "ssh_timeout_seconds": 900,
    "poll_interval_seconds": 10,
}
|
|
|
|
# Default local transfer paths
|
|
DEFAULT_TRANSFER: dict[str, Any] = {
    # Both paths are joined onto the project root and resolved by
    # PipelineConfig.local_output_path / PipelineConfig.local_data_path.
    "local_output_dir": "classifier/outputs",
    "local_data_dir": "data",
}
|
|
|
|
|
|
# Merge nested dictionaries recursively
|
|
# Used when a user override file should replace only specific nested keys
|
|
# (for example, overriding just search.max_dph_total without redefining search)
|
|
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict equal to ``base`` with ``override`` layered on top.

    When a key holds a dict on both sides the two dicts are merged
    recursively; in every other case the value from ``override`` replaces
    the one from ``base``. Neither argument is modified, but values that are
    not merged are shared with the inputs rather than copied.
    """
    merged = dict(base)
    for name, incoming in override.items():
        existing = merged.get(name)
        mergeable = isinstance(incoming, dict) and isinstance(existing, dict)
        merged[name] = _deep_merge(existing, incoming) if mergeable else incoming
    return merged
|
|
|
|
|
|
# Runtime pipeline config container and loader
|
|
@dataclass(slots=True)
|
|
class PipelineConfig:
|
|
search: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_SEARCH))
|
|
instance: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_INSTANCE))
|
|
remote: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_REMOTE))
|
|
transfer: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_TRANSFER))
|
|
keep_on_failure: bool = False
|
|
|
|
# Load defaults, apply overrides, and return a PipelineConfig instance
|
|
@classmethod
|
|
def load(cls, project_root: Path, override_path: Path | None = None) -> "PipelineConfig":
|
|
load_dotenv(project_root / "pipeline" / ".env")
|
|
|
|
# Load default settings from vast.json
|
|
defaults_path = project_root / "pipeline" / "defaults" / "vast.json"
|
|
with open(defaults_path, encoding="utf-8") as handle:
|
|
raw = json.load(handle)
|
|
|
|
# Apply user-provided override config (deep merge keeps unspecified defaults)
|
|
if override_path is not None:
|
|
with open(override_path, encoding="utf-8") as handle:
|
|
raw = _deep_merge(raw, json.load(handle))
|
|
return cls(
|
|
search=raw.get("search", {}),
|
|
instance=raw.get("instance", {}),
|
|
remote=raw.get("remote", {}),
|
|
transfer=raw.get("transfer", {}),
|
|
keep_on_failure=raw.get("keep_on_failure", False),
|
|
)
|
|
|
|
# Convert the internal "search" config shape into Vast.ai query format
|
|
def build_offer_query(self, *, price_cap: float | None = None) -> dict[str, Any]:
|
|
s = self.search
|
|
query: dict[str, Any] = {
|
|
"limit": s["limit"],
|
|
"order": [[s["order_by"], s["order_direction"]]],
|
|
"type": s["offer_type"],
|
|
"verified": {"eq": s["verified_only"]},
|
|
"rentable": {"eq": s["rentable"]},
|
|
"rented": {"eq": s["rented"]},
|
|
"num_gpus": {"gte": s["num_gpus"]},
|
|
"reliability2": {"gte": s["min_reliability"]},
|
|
"gpu_ram": {"gte": s["min_cuda_ram_mb"]},
|
|
"cpu_ram": {"gte": s["min_cpu_ram_mb"]},
|
|
"disk_space": {"gte": s["min_disk_space_gb"]},
|
|
"direct_port_count": {"gte": s["min_direct_port_count"]},
|
|
}
|
|
# CLI-provided price cap takes precedence over config default
|
|
max_price = price_cap or s.get("max_dph_total")
|
|
if max_price is not None:
|
|
query["dph_total"] = {"lte": max_price}
|
|
gpu_names = s.get("gpu_names", [])
|
|
if gpu_names:
|
|
query["gpu_name"] = {"in": gpu_names}
|
|
excluded = s.get("countries_exclude", [])
|
|
if excluded:
|
|
query["geolocation"] = {"notin": excluded}
|
|
return query
|
|
|
|
# Resolve the local directory where training outputs are saved
|
|
def local_output_path(self, project_root: Path) -> Path:
|
|
return (project_root / self.transfer["local_output_dir"]).resolve()
|
|
|
|
# Resolve the local directory where datasets are stored
|
|
def local_data_path(self, project_root: Path) -> Path:
|
|
return (project_root / self.transfer["local_data_dir"]).resolve()
|
|
|
|
|
|
# Read the Vast.ai API key from the environment, failing fast if unset
|
|
def resolve_api_key() -> str:
    """Return the value of the ``VAST_API_KEY`` environment variable.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    if token := os.environ.get("VAST_API_KEY"):
        return token
    raise RuntimeError("VAST_API_KEY is not set.")
|
|
|
|
|
|
# Resolve the local SSH private key path, defaulting to ~/.ssh/id_ed25519
|
|
def resolve_ssh_private_key() -> Path:
    """Return the path of the local SSH private key.

    The path is read from ``VAST_SSH_PRIVATE_KEY``; when the variable is
    unset or empty, ``~/.ssh/id_ed25519`` is used. ``~`` is expanded.

    Raises:
        RuntimeError: If the resolved path is not an existing file.
    """
    # "or" (not a .get default) so that an empty variable also falls back;
    # Path("") would otherwise resolve to the current directory.
    raw = os.environ.get("VAST_SSH_PRIVATE_KEY") or "~/.ssh/id_ed25519"
    path = Path(raw).expanduser()
    # A key must be a file; a directory at that location is not usable.
    if not path.is_file():
        raise RuntimeError(
            f"SSH private key not found at {path}. Set VAST_SSH_PRIVATE_KEY to the correct path."
        )
    return path
|
|
|
|
|
|
# Read the corresponding .pub file for a given private key
|
|
def read_public_key(private_key_path: Path) -> str:
    """Return the contents of the ``.pub`` file that sits next to a private key.

    The public key path is the private key path with ``.pub`` appended. The
    file is read as UTF-8 and surrounding whitespace is stripped.

    Raises:
        RuntimeError: If the ``.pub`` file does not exist.
    """
    public_key_path = Path(str(private_key_path) + ".pub")
    if public_key_path.exists():
        contents = public_key_path.read_text(encoding="utf-8")
        return contents.strip()
    raise RuntimeError(
        f"SSH public key not found at {public_key_path}. Generate it or set VAST_SSH_PRIVATE_KEY."
    )
|