Clean state
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Load environment variables from pipeline/.env if present
# This file is intentionally optional and gitignored
# Expected keys:
# VAST_API_KEY=<your-api-key>
# VAST_SSH_PRIVATE_KEY=/home/your-user/.ssh/id_ed25519
def load_dotenv(dotenv_path: Path) -> None:
    """Load KEY=VALUE pairs from *dotenv_path* into ``os.environ``.

    A missing file is silently ignored (the .env file is optional).
    Blank lines, ``#`` comment lines, and lines without ``=`` are skipped.
    Values already present in the environment are never overwritten
    (``setdefault``), so the real environment always wins over the file.

    Args:
        dotenv_path: Path to the optional .env file.
    """
    if not dotenv_path.exists():
        return
    for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        value = value.strip()
        # Remove exactly one pair of matching surrounding quotes.
        # (The previous .strip("'").strip('"') removed any number of
        # leading/trailing quote characters of each kind, mangling
        # values such as '"hello"' or unbalanced ones like 'abc.)
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]
        os.environ.setdefault(key, value)
|
||||
|
||||
# Default offer search settings (baseline filters for Vast.ai offer queries)
DEFAULT_SEARCH: dict[str, Any] = {
    "limit": 100,                  # max offers returned per query
    "order_by": "dlperf",          # sort key (deep-learning perf score)
    "order_direction": "desc",
    "sort_mode": "performance",
    "offer_type": "ondemand",
    "verified_only": True,         # only datacenter-verified hosts
    "rentable": True,
    "rented": False,
    "num_gpus": 1,
    "min_reliability": 0.98,
    "min_cuda_ram_mb": 0,
    "min_cpu_ram_mb": 0,
    "min_disk_space_gb": 0,
    "min_direct_port_count": 1,    # need at least one direct port for SSH
    "max_dph_total": 0.20,         # price cap in $/hour
    "gpu_names": ["RTX 3090", "RTX 3090 Ti"],
    "countries_exclude": [],       # ISO country codes to skip
}
|
||||
|
||||
# Default instance settings (how a rented machine is provisioned)
DEFAULT_INSTANCE: dict[str, Any] = {
    "image": "vastai/pytorch:latest",  # docker image booted on the host
    "disk_gb": 48,                     # disk allocation in GB
    "target_state": "running",
    "label_prefix": "drl-proj",        # instance label namespace
    "template_hash_id": None,          # optional Vast.ai template id
}
|
||||
|
||||
# Default remote SSH/workspace settings (paths and timeouts on the instance)
DEFAULT_REMOTE: dict[str, Any] = {
    "ssh_user": "root",
    "workspace_dir": "/workspace/DRL_PROJ",
    "remote_data_dir": "/workspace/data/DeepFakeFace",
    "remote_output_root": "classifier/outputs",  # relative to workspace_dir
    "ssh_timeout_seconds": 900,
    "poll_interval_seconds": 10,
}
|
||||
|
||||
# Default local transfer paths (relative to the project root)
DEFAULT_TRANSFER: dict[str, Any] = {
    "local_output_dir": "classifier/outputs",  # where results are pulled to
    "local_data_dir": "data",                  # where datasets live locally
}
|
||||
|
||||
|
||||
# Merge nested dictionaries recursively
|
||||
# Used when a user override file should replace only specific nested keys
|
||||
# (for example, overriding just search.max_dph_total without redefining search)
|
||||
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
|
||||
result = dict(base)
|
||||
for key, value in override.items():
|
||||
if isinstance(value, dict) and isinstance(result.get(key), dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
# Runtime pipeline config container and loader
|
||||
@dataclass(slots=True)
|
||||
class PipelineConfig:
|
||||
search: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_SEARCH))
|
||||
instance: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_INSTANCE))
|
||||
remote: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_REMOTE))
|
||||
transfer: dict[str, Any] = field(default_factory=lambda: dict(DEFAULT_TRANSFER))
|
||||
keep_on_failure: bool = False
|
||||
|
||||
# Load defaults, apply overrides, and return a PipelineConfig instance
|
||||
@classmethod
|
||||
def load(cls, project_root: Path, override_path: Path | None = None) -> "PipelineConfig":
|
||||
load_dotenv(project_root / "pipeline" / ".env")
|
||||
|
||||
# Load default settings from vast.json
|
||||
defaults_path = project_root / "pipeline" / "defaults" / "vast.json"
|
||||
with open(defaults_path, encoding="utf-8") as handle:
|
||||
raw = json.load(handle)
|
||||
|
||||
# Apply user-provided override config (deep merge keeps unspecified defaults)
|
||||
if override_path is not None:
|
||||
with open(override_path, encoding="utf-8") as handle:
|
||||
raw = _deep_merge(raw, json.load(handle))
|
||||
return cls(
|
||||
search=raw.get("search", {}),
|
||||
instance=raw.get("instance", {}),
|
||||
remote=raw.get("remote", {}),
|
||||
transfer=raw.get("transfer", {}),
|
||||
keep_on_failure=raw.get("keep_on_failure", False),
|
||||
)
|
||||
|
||||
# Convert the internal "search" config shape into Vast.ai query format
|
||||
def build_offer_query(self, *, price_cap: float | None = None) -> dict[str, Any]:
|
||||
s = self.search
|
||||
query: dict[str, Any] = {
|
||||
"limit": s["limit"],
|
||||
"order": [[s["order_by"], s["order_direction"]]],
|
||||
"type": s["offer_type"],
|
||||
"verified": {"eq": s["verified_only"]},
|
||||
"rentable": {"eq": s["rentable"]},
|
||||
"rented": {"eq": s["rented"]},
|
||||
"num_gpus": {"gte": s["num_gpus"]},
|
||||
"reliability2": {"gte": s["min_reliability"]},
|
||||
"gpu_ram": {"gte": s["min_cuda_ram_mb"]},
|
||||
"cpu_ram": {"gte": s["min_cpu_ram_mb"]},
|
||||
"disk_space": {"gte": s["min_disk_space_gb"]},
|
||||
"direct_port_count": {"gte": s["min_direct_port_count"]},
|
||||
}
|
||||
# CLI-provided price cap takes precedence over config default
|
||||
max_price = price_cap or s.get("max_dph_total")
|
||||
if max_price is not None:
|
||||
query["dph_total"] = {"lte": max_price}
|
||||
gpu_names = s.get("gpu_names", [])
|
||||
if gpu_names:
|
||||
query["gpu_name"] = {"in": gpu_names}
|
||||
excluded = s.get("countries_exclude", [])
|
||||
if excluded:
|
||||
query["geolocation"] = {"notin": excluded}
|
||||
return query
|
||||
|
||||
# Resolve the local directory where training outputs are saved
|
||||
def local_output_path(self, project_root: Path) -> Path:
|
||||
return (project_root / self.transfer["local_output_dir"]).resolve()
|
||||
|
||||
# Resolve the local directory where datasets are stored
|
||||
def local_data_path(self, project_root: Path) -> Path:
|
||||
return (project_root / self.transfer["local_data_dir"]).resolve()
|
||||
|
||||
|
||||
# Read the Vast.ai API key from the environment, failing fast if unset
def resolve_api_key() -> str:
    """Return the value of ``VAST_API_KEY``.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    if key := os.environ.get("VAST_API_KEY"):
        return key
    raise RuntimeError("VAST_API_KEY is not set.")
|
||||
|
||||
|
||||
# Resolve the local SSH private key path, defaulting to ~/.ssh/id_ed25519
def resolve_ssh_private_key() -> Path:
    """Return the path to the SSH private key used for instance access.

    Honors ``VAST_SSH_PRIVATE_KEY`` when set; otherwise falls back to
    ``~/.ssh/id_ed25519``. The resolved path must exist on disk.

    Raises:
        RuntimeError: If no file exists at the resolved path.
    """
    configured = os.environ.get("VAST_SSH_PRIVATE_KEY", "~/.ssh/id_ed25519")
    key_path = Path(configured).expanduser()
    if key_path.exists():
        return key_path
    raise RuntimeError(
        f"SSH private key not found at {key_path}. Set VAST_SSH_PRIVATE_KEY to the correct path."
    )
|
||||
|
||||
|
||||
# Read the corresponding .pub file for a given private key
def read_public_key(private_key_path: Path) -> str:
    """Return the trimmed contents of ``<private_key_path>.pub``.

    Appends ``.pub`` to the full file name (not ``with_suffix``, which
    would drop an existing extension such as ``.pem``).

    Raises:
        RuntimeError: If the sibling ``.pub`` file does not exist.
    """
    public_key_path = private_key_path.parent / (private_key_path.name + ".pub")
    if public_key_path.exists():
        return public_key_path.read_text(encoding="utf-8").strip()
    raise RuntimeError(
        f"SSH public key not found at {public_key_path}. Generate it or set VAST_SSH_PRIVATE_KEY."
    )
|
||||
Reference in New Issue
Block a user