"""Command-line interface for the ephemeral Vast.ai training pipeline.

Subcommands:
    offers  — inspect/rank available Vast.ai offers
    run     — create an instance, train, fetch outputs, destroy it
    up      — create an instance and print SSH details
    status  — show details for an existing instance
    down    — destroy an instance
"""

import argparse
from pathlib import Path

from pipeline.orchestrator import EphemeralVastRunner, RunOptions


def _resolve_configs(raw: list[str]) -> list[Path]:
    """Expand CLI config arguments into concrete config file paths.

    Accepts one or more explicit config paths, or a single directory, in
    which case every ``*.json`` file inside is used (sorted by name,
    skipping files whose name starts with an underscore).

    Args:
        raw: positional ``configs`` arguments as given on the command line.

    Returns:
        List of config file paths.

    Raises:
        ValueError: a directory was given but contains no JSON configs.
    """
    if len(raw) == 1 and Path(raw[0]).is_dir():
        configs = sorted(
            p for p in Path(raw[0]).glob("*.json") if not p.name.startswith("_")
        )
        if not configs:
            raise ValueError(f"No JSON configs found in directory: {raw[0]}")
        return configs
    return [Path(p) for p in raw]


def build_parser() -> argparse.ArgumentParser:
    """Build the argparse CLI with subcommands: offers, run, up, status, down."""
    parser = argparse.ArgumentParser(description="Ephemeral Vast.ai training pipeline.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # Shared offer-related flags used by both "offers" and "run" subcommands.
    def add_offer_options(p: argparse.ArgumentParser) -> None:
        p.add_argument(
            "--pipeline-config",
            default=None,
            help="Optional JSON file with pipeline overrides.",
        )
        p.add_argument(
            "--sort",
            choices=["price", "performance", "dlp_per_dollar"],
            default=None,
            help="Override offer ranking mode.",
        )
        p.add_argument(
            "--region",
            default=None,
            help="Filter offers by region (e.g. 'europe', 'Portugal').",
        )
        p.add_argument(
            "--price",
            type=float,
            default=None,
            help="Override max hourly price cap in USD.",
        )
        p.add_argument(
            "--select-offer",
            action="store_true",
            help="Interactively choose an offer.",
        )

    # Shared template-related flags used by "run" and "up" subcommands.
    def add_template_options(p: argparse.ArgumentParser) -> None:
        p.add_argument(
            "--select-template",
            action="store_true",
            help="Interactively choose a Vast.ai template.",
        )
        p.add_argument(
            "--template",
            default=None,
            help="Template hash ID to use for instance creation.",
        )

    # ── offers ──────────────────────────────────────────────────────────
    offers_parser = subparsers.add_parser("offers", help="Inspect available Vast offers.")
    add_offer_options(offers_parser)
    offers_parser.add_argument(
        "--list-regions",
        action="store_true",
        help="List matching regions and exit.",
    )
    offers_parser.add_argument(
        "--limit-output",
        type=int,
        default=10,
        help="How many offers to print (default: 10).",
    )

    # ── run ─────────────────────────────────────────────────────────────
    run_parser = subparsers.add_parser(
        "run", help="Create instance, train, fetch outputs, destroy it."
    )
    run_parser.add_argument(
        "configs",
        nargs="+",
        help="One or more config paths, or a directory of JSON configs.",
    )
    add_offer_options(run_parser)
    add_template_options(run_parser)
    run_parser.add_argument(
        "--download-data",
        action="store_true",
        help="Force download dataset via HuggingFace on the remote before training (auto-downloads if missing).",
    )
    run_parser.add_argument(
        "--send-cropped",
        action="store_true",
        help="Rsync local cropped/ subdirectory to remote before training (sends only classifier or generator based on config).",
    )
    run_parser.add_argument(
        "--keep-on-failure",
        action="store_true",
        help="Keep instance on failure.",
    )
    run_parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Search and print offers without creating instance.",
    )
    run_parser.add_argument(
        "--no-gpu",
        action="store_true",
        help="Disable GPU training on remote (use CPU instead).",
    )

    # ── up ──────────────────────────────────────────────────────────────
    up_parser = subparsers.add_parser("up", help="Create instance and print SSH details.")
    up_parser.add_argument(
        "--pipeline-config",
        default=None,
        help="Optional JSON file with pipeline overrides.",
    )
    add_template_options(up_parser)
    up_parser.add_argument(
        "--label",
        default=None,
        help="Optional label for the instance.",
    )

    # ── status ──────────────────────────────────────────────────────────
    status_parser = subparsers.add_parser("status", help="Show instance details.")
    status_parser.add_argument("instance_id", type=int)
    status_parser.add_argument(
        "--pipeline-config",
        default=None,
        help="Optional JSON file with pipeline overrides.",
    )

    # ── down ────────────────────────────────────────────────────────────
    down_parser = subparsers.add_parser("down", help="Destroy an instance.")
    down_parser.add_argument("instance_id", type=int)
    down_parser.add_argument(
        "--pipeline-config",
        default=None,
        help="Optional JSON file with pipeline overrides.",
    )

    return parser


def main(argv=None) -> int:
    """Parse CLI args and dispatch to the appropriate runner method.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code (0 on success).
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Every subcommand defines --pipeline-config; getattr is defensive.
    override = Path(args.pipeline_config) if getattr(args, "pipeline_config", None) else None
    runner = EphemeralVastRunner(override)

    # ── dispatch ────────────────────────────────────────────────────────
    if args.command == "offers":
        runner.offers(
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            select_offer=args.select_offer,
            list_regions=args.list_regions,
            limit_output=args.limit_output,
        )
        return 0

    if args.command == "run":
        opts = RunOptions(
            download_data=args.download_data,
            send_cropped=args.send_cropped,
            keep_on_failure=args.keep_on_failure,
            dry_run=args.dry_run,
            select_offer=args.select_offer,
            select_template=args.select_template,
            template_hash=args.template,
            sort_mode=args.sort,
            region=args.region,
            price_cap=args.price,
            use_gpu=not args.no_gpu,  # Default to True, disable with --no-gpu
        )
        runner.run(_resolve_configs(args.configs), opts)
        return 0

    if args.command == "up":
        up_opts = RunOptions(
            select_template=getattr(args, "select_template", False),
            template_hash=getattr(args, "template", None),
        )
        runner.up(label=args.label, opts=up_opts)
        return 0

    if args.command == "status":
        runner.status(args.instance_id)
        return 0

    if args.command == "down":
        runner.down(args.instance_id)
        return 0

    # NOTE: unreachable in practice — subparsers are required=True, so
    # argparse rejects unknown/missing commands before we get here. Kept
    # as a defensive failure path.
    parser.error(f"Unknown command: {args.command}")
    return 2