import json import os from dataclasses import dataclass, field from pathlib import Path from typing import Any # Load environment variables from pipeline/.env if present # This file is intentionally optional and gitignored # Expected keys: # VAST_API_KEY= # VAST_SSH_PRIVATE_KEY=/home/your-user/.ssh/id_ed25519 def load_dotenv(dotenv_path: Path) -> None: if not dotenv_path.exists(): return for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines(): line = raw_line.strip() if not line or line.startswith("#"): continue if "=" not in line: continue key, value = line.split("=", 1) key = key.strip() value = value.strip().strip("'").strip('"') os.environ.setdefault(key, value) # Default offer search settings DEFAULT_SEARCH: dict[str, Any] = { "limit": 100, "order_by": "dlperf", "order_direction": "desc", "sort_mode": "performance", "offer_type": "ondemand", "verified_only": True, "rentable": True, "rented": False, "num_gpus": 1, "min_reliability": 0.98, "min_cuda_ram_mb": 0, "min_cpu_ram_mb": 0, "min_disk_space_gb": 0, "min_direct_port_count": 1, "max_dph_total": 0.20, "gpu_names": ["RTX 3090", "RTX 3090 Ti"], "countries_exclude": [], } # Default instance settings DEFAULT_INSTANCE: dict[str, Any] = { "image": "vastai/pytorch:latest", "disk_gb": 48, "target_state": "running", "label_prefix": "drl-proj", "template_hash_id": None, } # Default remote SSH/workspace settings DEFAULT_REMOTE: dict[str, Any] = { "ssh_user": "root", "workspace_dir": "/workspace/DRL_PROJ", "remote_data_dir": "/workspace/data/DeepFakeFace", "remote_output_root": "classifier/outputs", "ssh_timeout_seconds": 900, "poll_interval_seconds": 10, } # Default local transfer paths DEFAULT_TRANSFER: dict[str, Any] = { "local_output_dir": "classifier/outputs", "local_data_dir": "data", } # Merge nested dictionaries recursively # Used when a user override file should replace only specific nested keys # (for example, overriding just search.max_dph_total without redefining search) def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: result = dict(base) for key, value in override.items(): if isinstance(value, dict) and isinstance(result.get(key), dict): result[key] = _deep_merge(result[key], value) else: result[key] = value return result # Runtime pipeline config container and loader @dataclass(slots=True) class PipelineConfig: search: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_SEARCH)) instance: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_INSTANCE)) remote: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_REMOTE)) transfer: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_TRANSFER)) keep_on_failure: bool = False # Load defaults, apply overrides, and return a PipelineConfig instance @classmethod def load(cls, project_root: Path, override_path: Path | None = None) -> "PipelineConfig": load_dotenv(project_root / "pipeline" / ".env") # Load default settings from vast.json defaults_path = project_root / "pipeline" / "defaults" / "vast.json" with open(defaults_path, encoding="utf-8") as handle: raw = json.load(handle) # Apply user-provided override config (deep merge keeps unspecified defaults) if override_path is not None: with open(override_path, encoding="utf-8") as handle: raw = _deep_merge(raw, json.load(handle)) return cls( search=raw.get("search", {}), instance=raw.get("instance", {}), remote=raw.get("remote", {}), transfer=raw.get("transfer", {}), keep_on_failure=raw.get("keep_on_failure", False), ) # Convert the internal "search" config shape into Vast.ai query format def build_offer_query(self, *, price_cap: float | None = None) -> dict[str, Any]: s = self.search query: dict[str, Any] = { "limit": s["limit"], "order": [[s["order_by"], s["order_direction"]]], "type": s["offer_type"], "verified": {"eq": s["verified_only"]}, "rentable": {"eq": s["rentable"]}, "rented": {"eq": s["rented"]}, "num_gpus": {"gte": s["num_gpus"]}, "reliability2": {"gte": s["min_reliability"]}, "gpu_ram": {"gte": s["min_cuda_ram_mb"]}, "cpu_ram": {"gte": s["min_cpu_ram_mb"]}, "disk_space": {"gte": s["min_disk_space_gb"]}, "direct_port_count": {"gte": s["min_direct_port_count"]}, } # CLI-provided price cap takes precedence over config default max_price = price_cap or s.get("max_dph_total") if max_price is not None: query["dph_total"] = {"lte": max_price} gpu_names = s.get("gpu_names", []) if gpu_names: query["gpu_name"] = {"in": gpu_names} excluded = s.get("countries_exclude", []) if excluded: query["geolocation"] = {"notin": excluded} return query # Resolve the local directory where training outputs are saved def local_output_path(self, project_root: Path) -> Path: return (project_root / self.transfer["local_output_dir"]).resolve() # Resolve the local directory where datasets are stored def local_data_path(self, project_root: Path) -> Path: return (project_root / self.transfer["local_data_dir"]).resolve() # Read the Vast.ai API key from the environment, failing fast if unset def resolve_api_key() -> str: api_key = os.environ.get("VAST_API_KEY") if not api_key: raise RuntimeError("VAST_API_KEY is not set.") return api_key # Resolve the local SSH private key path, defaulting to ~/.ssh/id_ed25519 def resolve_ssh_private_key() -> Path: raw = os.environ.get("VAST_SSH_PRIVATE_KEY", "~/.ssh/id_ed25519") path = Path(raw).expanduser() if not path.exists(): raise RuntimeError( f"SSH private key not found at {path}. Set VAST_SSH_PRIVATE_KEY to the correct path." ) return path # Read the corresponding .pub file for a given private key def read_public_key(private_key_path: Path) -> str: public_key_path = Path(f"{private_key_path}.pub") if not public_key_path.exists(): raise RuntimeError( f"SSH public key not found at {public_key_path}. Generate it or set VAST_SSH_PRIVATE_KEY." ) return public_key_path.read_text(encoding="utf-8").strip()