Files
DRL_PROJ/pipeline/config.py
T
Johnny Fernandes bb3dfb92d5 Clean state
2026-04-30 01:25:39 +01:00

183 lines
6.6 KiB
Python

import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
def load_dotenv(dotenv_path: Path) -> None:
    """Populate ``os.environ`` from an optional KEY=VALUE dotenv file.

    The file (pipeline/.env) is intentionally optional and gitignored.
    Expected keys:
        VAST_API_KEY=<your-api-key>
        VAST_SSH_PRIVATE_KEY=/home/your-user/.ssh/id_ed25519

    Blank lines, ``#`` comments, and lines without ``=`` are skipped.
    Already-set environment variables are never overwritten (setdefault),
    so the real environment always wins over the file.
    """
    if not dotenv_path.exists():
        return
    for raw in dotenv_path.read_text(encoding="utf-8").splitlines():
        entry = raw.strip()
        # Skip blanks, comments, and anything that is not a KEY=VALUE pair.
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        name, _, raw_value = entry.partition("=")
        cleaned = raw_value.strip().strip("'").strip('"')
        os.environ.setdefault(name.strip(), cleaned)
# Default offer search settings.
# Consumed by PipelineConfig.build_offer_query, which turns the min_* keys
# into {"gte": ...} filters and max_dph_total into a {"lte": ...} price cap.
DEFAULT_SEARCH: dict[str, Any] = {
    "limit": 100,                        # max offers returned per query
    "order_by": "dlperf",                # sort field, sent as order=[[order_by, order_direction]]
    "order_direction": "desc",
    "sort_mode": "performance",          # not read by build_offer_query; presumably used elsewhere — verify
    "offer_type": "ondemand",            # mapped to the "type" query field
    "verified_only": True,               # mapped to verified {"eq": ...}
    "rentable": True,
    "rented": False,                     # exclude offers already rented
    "num_gpus": 1,                       # minimum GPU count
    "min_reliability": 0.98,             # mapped to "reliability2"
    "min_cuda_ram_mb": 0,                # mapped to "gpu_ram"
    "min_cpu_ram_mb": 0,                 # mapped to "cpu_ram"
    "min_disk_space_gb": 0,              # mapped to "disk_space"
    "min_direct_port_count": 1,          # mapped to "direct_port_count"
    "max_dph_total": 0.20,               # price ceiling (dph_total; presumably $/hour)
    "gpu_names": ["RTX 3090", "RTX 3090 Ti"],  # allow-list ({"in": ...}); empty list disables the filter
    "countries_exclude": [],             # geolocation deny-list ({"notin": ...})
}
# Default instance settings (semantics not exercised in this module;
# presumably consumed by the instance-provisioning step — verify there).
DEFAULT_INSTANCE: dict[str, Any] = {
    "image": "vastai/pytorch:latest",  # Docker image to launch
    "disk_gb": 48,                     # disk allocation for the instance
    "target_state": "running",         # desired state after provisioning
    "label_prefix": "drl-proj",        # prefix for instance labels
    "template_hash_id": None,          # optional template id; None presumably means plain image
}
# Default remote SSH/workspace settings (consumed outside this module;
# path semantics below are inferred from key names — verify against callers).
DEFAULT_REMOTE: dict[str, Any] = {
    "ssh_user": "root",
    "workspace_dir": "/workspace/DRL_PROJ",             # project checkout on the instance
    "remote_data_dir": "/workspace/data/DeepFakeFace",  # dataset location on the instance
    "remote_output_root": "classifier/outputs",         # presumably relative to workspace_dir
    "ssh_timeout_seconds": 900,
    "poll_interval_seconds": 10,
}
# Default local transfer paths, relative to the project root; resolved by
# PipelineConfig.local_output_path / PipelineConfig.local_data_path.
DEFAULT_TRANSFER: dict[str, Any] = {
    "local_output_dir": "classifier/outputs",
    "local_data_dir": "data",
}
# Merge nested dictionaries recursively
# Used when a user override file should replace only specific nested keys
# (for example, overriding just search.max_dph_total without redefining search)
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
result = dict(base)
for key, value in override.items():
if isinstance(value, dict) and isinstance(result.get(key), dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
# Runtime pipeline config container and loader
@dataclass(slots=True)
class PipelineConfig:
search: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_SEARCH))
instance: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_INSTANCE))
remote: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_REMOTE))
transfer: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_TRANSFER))
keep_on_failure: bool = False
# Load defaults, apply overrides, and return a PipelineConfig instance
@classmethod
def load(cls, project_root: Path, override_path: Path | None = None) -> "PipelineConfig":
load_dotenv(project_root / "pipeline" / ".env")
# Load default settings from vast.json
defaults_path = project_root / "pipeline" / "defaults" / "vast.json"
with open(defaults_path, encoding="utf-8") as handle:
raw = json.load(handle)
# Apply user-provided override config (deep merge keeps unspecified defaults)
if override_path is not None:
with open(override_path, encoding="utf-8") as handle:
raw = _deep_merge(raw, json.load(handle))
return cls(
search=raw.get("search", {}),
instance=raw.get("instance", {}),
remote=raw.get("remote", {}),
transfer=raw.get("transfer", {}),
keep_on_failure=raw.get("keep_on_failure", False),
)
# Convert the internal "search" config shape into Vast.ai query format
def build_offer_query(self, *, price_cap: float | None = None) -> dict[str, Any]:
s = self.search
query: dict[str, Any] = {
"limit": s["limit"],
"order": [[s["order_by"], s["order_direction"]]],
"type": s["offer_type"],
"verified": {"eq": s["verified_only"]},
"rentable": {"eq": s["rentable"]},
"rented": {"eq": s["rented"]},
"num_gpus": {"gte": s["num_gpus"]},
"reliability2": {"gte": s["min_reliability"]},
"gpu_ram": {"gte": s["min_cuda_ram_mb"]},
"cpu_ram": {"gte": s["min_cpu_ram_mb"]},
"disk_space": {"gte": s["min_disk_space_gb"]},
"direct_port_count": {"gte": s["min_direct_port_count"]},
}
# CLI-provided price cap takes precedence over config default
max_price = price_cap or s.get("max_dph_total")
if max_price is not None:
query["dph_total"] = {"lte": max_price}
gpu_names = s.get("gpu_names", [])
if gpu_names:
query["gpu_name"] = {"in": gpu_names}
excluded = s.get("countries_exclude", [])
if excluded:
query["geolocation"] = {"notin": excluded}
return query
# Resolve the local directory where training outputs are saved
def local_output_path(self, project_root: Path) -> Path:
return (project_root / self.transfer["local_output_dir"]).resolve()
# Resolve the local directory where datasets are stored
def local_data_path(self, project_root: Path) -> Path:
return (project_root / self.transfer["local_data_dir"]).resolve()
def resolve_api_key() -> str:
    """Return the Vast.ai API key from the environment.

    Fails fast with RuntimeError when VAST_API_KEY is unset or empty, so
    misconfiguration surfaces before any provisioning work starts.
    """
    key = os.environ.get("VAST_API_KEY")
    if key:
        return key
    raise RuntimeError("VAST_API_KEY is not set.")
def resolve_ssh_private_key() -> Path:
    """Resolve the local SSH private key path.

    Reads VAST_SSH_PRIVATE_KEY, defaulting to ~/.ssh/id_ed25519, expands
    ``~`` and raises RuntimeError when no file exists at the result.
    """
    key_path = Path(
        os.environ.get("VAST_SSH_PRIVATE_KEY", "~/.ssh/id_ed25519")
    ).expanduser()
    if key_path.exists():
        return key_path
    raise RuntimeError(
        f"SSH private key not found at {key_path}. Set VAST_SSH_PRIVATE_KEY to the correct path."
    )
def read_public_key(private_key_path: Path) -> str:
    """Return the contents of the ``.pub`` file paired with a private key.

    The public key is expected at ``<private_key_path>.pub``; raises
    RuntimeError when it is missing. Trailing whitespace/newline is
    stripped from the returned key text.
    """
    pub_path = Path(f"{private_key_path}.pub")
    if not pub_path.exists():
        raise RuntimeError(
            f"SSH public key not found at {pub_path}. Generate it or set VAST_SSH_PRIVATE_KEY."
        )
    return pub_path.read_text(encoding="utf-8").strip()