Files
DRL_PROJ/pipeline/cli.py
T
2026-04-30 03:21:49 +01:00

131 lines
6.9 KiB
Python

import argparse
from pathlib import Path
from pipeline.orchestrator import EphemeralVastRunner, RunOptions
# Accept one or more config files, or a single directory (all *.json inside, sorted)
def _resolve_configs(raw: list[str]) -> list[Path]:
if len(raw) == 1 and Path(raw[0]).is_dir():
configs = sorted(p for p in Path(raw[0]).glob("*.json") if not p.name.startswith("_"))
if not configs:
raise ValueError(f"No JSON configs found in directory: {raw[0]}")
return configs
return [Path(p) for p in raw]
# Build the argparse CLI with subcommands: offers, run, up, status, down
def build_parser() -> argparse.ArgumentParser:
    """Construct the top-level argument parser and its five subcommands."""
    parser = argparse.ArgumentParser(description="Ephemeral Vast.ai training pipeline.")
    commands = parser.add_subparsers(dest="command", required=True)

    def _add_pipeline_config(sub: argparse.ArgumentParser) -> None:
        # Every subcommand can point at a JSON file of pipeline overrides.
        sub.add_argument("--pipeline-config", default=None, help="Optional JSON file with pipeline overrides.")

    def _add_offer_flags(sub: argparse.ArgumentParser) -> None:
        # Offer search/ranking flags shared by "offers" and "run".
        _add_pipeline_config(sub)
        sub.add_argument("--sort", choices=["price", "performance", "dlp_per_dollar"], default=None, help="Override offer ranking mode.")
        sub.add_argument("--region", default=None, help="Filter offers by region (e.g. 'europe', 'Portugal').")
        sub.add_argument("--price", type=float, default=None, help="Override max hourly price cap in USD.")
        sub.add_argument("--select-offer", action="store_true", help="Interactively choose an offer.")

    def _add_template_flags(sub: argparse.ArgumentParser) -> None:
        # Template selection flags shared by "run" and "up".
        sub.add_argument("--select-template", action="store_true", help="Interactively choose a Vast.ai template.")
        sub.add_argument("--template", default=None, help="Template hash ID to use for instance creation.")

    # ── offers ──────────────────────────────────────────────────────────
    offers_cmd = commands.add_parser("offers", help="Inspect available Vast offers.")
    _add_offer_flags(offers_cmd)
    offers_cmd.add_argument("--list-regions", action="store_true", help="List matching regions and exit.")
    offers_cmd.add_argument("--limit-output", type=int, default=10, help="How many offers to print (default: 10).")

    # ── run ─────────────────────────────────────────────────────────────
    run_cmd = commands.add_parser("run", help="Create instance, train, fetch outputs, destroy it.")
    run_cmd.add_argument("configs", nargs="+", help="One or more config paths, or a directory of JSON configs.")
    _add_offer_flags(run_cmd)
    _add_template_flags(run_cmd)
    run_cmd.add_argument("--download-data", action="store_true", help="Force download dataset via HuggingFace on the remote before training (auto-downloads if missing).")
    run_cmd.add_argument("--send-cropped", action="store_true", help="Rsync local cropped/ subdirectory to remote before training (sends only classifier or generator based on config).")
    run_cmd.add_argument("--keep-on-failure", action="store_true", help="Keep instance on failure.")
    run_cmd.add_argument("--dry-run", action="store_true", help="Search and print offers without creating instance.")
    run_cmd.add_argument("--no-gpu", action="store_true", help="Disable GPU training on remote (use CPU instead).")

    # ── up ──────────────────────────────────────────────────────────────
    up_cmd = commands.add_parser("up", help="Create instance and print SSH details.")
    _add_pipeline_config(up_cmd)
    _add_template_flags(up_cmd)
    up_cmd.add_argument("--label", default=None, help="Optional label for the instance.")

    # ── status ──────────────────────────────────────────────────────────
    status_cmd = commands.add_parser("status", help="Show instance details.")
    status_cmd.add_argument("instance_id", type=int)
    _add_pipeline_config(status_cmd)

    # ── down ────────────────────────────────────────────────────────────
    down_cmd = commands.add_parser("down", help="Destroy an instance.")
    down_cmd.add_argument("instance_id", type=int)
    _add_pipeline_config(down_cmd)

    return parser
# Parse CLI args and dispatch to the appropriate runner method
def main(argv=None) -> int:
    """CLI entry point: parse arguments and dispatch to EphemeralVastRunner.

    Args:
        argv: Argument list for testing; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 on success. Bad usage never reaches a return: argparse raises
        ``SystemExit(2)`` itself.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    # All five subcommands define --pipeline-config, so the attribute always exists.
    override = Path(args.pipeline_config) if args.pipeline_config else None
    runner = EphemeralVastRunner(override)
    # ── dispatch ────────────────────────────────────────────────────────
    if args.command == "offers":
        runner.offers(
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            select_offer=args.select_offer,
            list_regions=args.list_regions,
            limit_output=args.limit_output,
        )
        return 0
    if args.command == "run":
        opts = RunOptions(
            download_data=args.download_data,
            send_cropped=args.send_cropped,
            keep_on_failure=args.keep_on_failure,
            dry_run=args.dry_run,
            select_offer=args.select_offer,
            select_template=args.select_template,
            template_hash=args.template,
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            use_gpu=not args.no_gpu,  # Default to True, disable with --no-gpu
        )
        runner.run(_resolve_configs(args.configs), opts)
        return 0
    if args.command == "up":
        # add_template_options guarantees both attributes on the "up" namespace,
        # so the previous getattr(...) fallbacks were redundant.
        up_opts = RunOptions(
            select_template=args.select_template,
            template_hash=args.template,
        )
        runner.up(label=args.label, opts=up_opts)
        return 0
    if args.command == "status":
        runner.status(args.instance_id)
        return 0
    if args.command == "down":
        runner.down(args.instance_id)
        return 0
    # Defensive only: add_subparsers(required=True) rejects unknown commands
    # during parse_args, and parser.error() raises SystemExit(2) rather than
    # returning — so the trailing return is formally unreachable.
    parser.error(f"Unknown command: {args.command}")
    return 2