131 lines
6.9 KiB
Python
131 lines
6.9 KiB
Python
import argparse
|
|
from pathlib import Path
|
|
|
|
from pipeline.orchestrator import EphemeralVastRunner, RunOptions
|
|
|
|
|
|
def _resolve_configs(raw: list[str]) -> list[Path]:
    """Expand ``run`` config arguments into a flat list of config file paths.

    Each argument is either:

    * a directory, which is replaced by the ``*.json`` regular files directly
      inside it, sorted, skipping names that start with ``_``; or
    * anything else, which is passed through unchanged as a ``Path``.

    Arguments keep their command-line order, so files and directories can be
    mixed (e.g. ``run base.json sweeps/``).

    Args:
        raw: Positional config arguments as given on the command line.

    Returns:
        The resolved config paths.

    Raises:
        ValueError: if a directory argument contains no matching JSON files.
    """
    resolved: list[Path] = []
    for item in raw:
        path = Path(item)
        if not path.is_dir():
            resolved.append(path)
            continue
        # Underscore-prefixed files are skipped; is_file() keeps
        # sub-directories named "*.json" out of the result.
        found = sorted(
            p for p in path.glob("*.json") if p.is_file() and not p.name.startswith("_")
        )
        if not found:
            raise ValueError(f"No JSON configs found in directory: {item}")
        resolved.extend(found)
    return resolved
|
|
|
|
|
|
def _add_pipeline_config_option(p: argparse.ArgumentParser) -> None:
    """Add the ``--pipeline-config`` flag that every subcommand accepts."""
    p.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")


def _add_offer_options(p: argparse.ArgumentParser) -> None:
    """Add the offer-search flags shared by the ``offers`` and ``run`` subcommands."""
    p.add_argument("--sort", choices=["price", "performance", "dlp_per_dollar"], default=None, help="Override offer ranking mode.")
    p.add_argument("--region", default=None, help="Filter offers by region (e.g. 'europe', 'Portugal').")
    p.add_argument("--price", type=float, default=None, help="Override max hourly price cap in USD.")
    p.add_argument("--select-offer", action="store_true", help="Interactively choose an offer.")


def _add_template_options(p: argparse.ArgumentParser) -> None:
    """Add the template-selection flags shared by the ``run`` and ``up`` subcommands."""
    p.add_argument("--select-template", action="store_true", help="Interactively choose a Vast.ai template.")
    p.add_argument("--template", default=None, help="Template hash ID to use for instance creation.")


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    The parser has one required subcommand, stored in ``args.command``:
    ``offers``, ``run``, ``up``, ``status`` or ``down``. Every subcommand
    accepts ``--pipeline-config``.

    Returns:
        The configured, unparsed ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Ephemeral Vast.ai training pipeline.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # ── offers ──────────────────────────────────────────────────────────
    offers_parser = subparsers.add_parser("offers", help="Inspect available Vast offers.")
    _add_pipeline_config_option(offers_parser)
    _add_offer_options(offers_parser)
    offers_parser.add_argument("--list-regions", action="store_true", help="List matching regions and exit.")
    offers_parser.add_argument("--limit-output", type=int, default=10, help="How many offers to print (default: 10).")

    # ── run ─────────────────────────────────────────────────────────────
    run_parser = subparsers.add_parser("run", help="Create instance, train, fetch outputs, destroy it.")
    run_parser.add_argument("configs", nargs="+", help="One or more config paths, or a directory of JSON configs.")
    _add_pipeline_config_option(run_parser)
    _add_offer_options(run_parser)
    _add_template_options(run_parser)
    run_parser.add_argument("--download-data", action="store_true", help="Force download dataset via HuggingFace on the remote before training (auto-downloads if missing).")
    run_parser.add_argument("--send-cropped", action="store_true", help="Rsync local cropped/ subdirectory to remote before training (sends only classifier or generator based on config).")
    run_parser.add_argument("--keep-on-failure", action="store_true", help="Keep instance on failure.")
    run_parser.add_argument("--dry-run", action="store_true", help="Search and print offers without creating instance.")
    run_parser.add_argument("--no-gpu", action="store_true", help="Disable GPU training on remote (use CPU instead).")

    # ── up ──────────────────────────────────────────────────────────────
    up_parser = subparsers.add_parser("up", help="Create instance and print SSH details.")
    _add_pipeline_config_option(up_parser)
    _add_template_options(up_parser)
    up_parser.add_argument("--label", default=None, help="Optional label for the instance.")

    # ── status ──────────────────────────────────────────────────────────
    status_parser = subparsers.add_parser("status", help="Show instance details.")
    status_parser.add_argument("instance_id", type=int)
    _add_pipeline_config_option(status_parser)

    # ── down ────────────────────────────────────────────────────────────
    down_parser = subparsers.add_parser("down", help="Destroy an instance.")
    down_parser.add_argument("instance_id", type=int)
    _add_pipeline_config_option(down_parser)

    return parser
|
|
|
|
|
|
def main(argv=None) -> int:
    """Parse CLI arguments and dispatch to the matching ``EphemeralVastRunner`` method.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` lets
            argparse fall back to ``sys.argv[1:]``.

    Returns:
        ``0`` once the selected command's runner method returns.

    Raises:
        SystemExit: from argparse when the arguments are invalid.
        ValueError: from ``_resolve_configs`` when a ``run`` config directory
            contains no JSON configs.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Resolve local config paths before constructing the runner so that bad
    # CLI input is reported before any runner setup happens.
    configs = _resolve_configs(args.configs) if args.command == "run" else None

    override = Path(args.pipeline_config) if getattr(args, "pipeline_config", None) else None
    runner = EphemeralVastRunner(override)

    # ── dispatch ────────────────────────────────────────────────────────
    if args.command == "offers":
        runner.offers(
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            select_offer=args.select_offer,
            list_regions=args.list_regions,
            limit_output=args.limit_output,
        )
        return 0

    if args.command == "run":
        opts = RunOptions(
            download_data=args.download_data,
            send_cropped=args.send_cropped,
            keep_on_failure=args.keep_on_failure,
            dry_run=args.dry_run,
            select_offer=args.select_offer,
            select_template=args.select_template,
            template_hash=args.template,
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            use_gpu=not args.no_gpu,  # GPU is on unless --no-gpu is given
        )
        runner.run(configs, opts)
        return 0

    if args.command == "up":
        # The "up" subparser always defines the template flags, so read them directly.
        up_opts = RunOptions(
            select_template=args.select_template,
            template_hash=args.template,
        )
        runner.up(label=args.label, opts=up_opts)
        return 0

    if args.command == "status":
        runner.status(args.instance_id)
        return 0

    if args.command == "down":
        runner.down(args.instance_id)
        return 0

    # Not reachable with the current parser: the subcommand is required and
    # argparse restricts it to the names handled above. ArgumentParser.error()
    # prints the usage message and exits with status 2; it never returns.
    parser.error(f"Unknown command: {args.command}")
|