""" Train a generative model from a config file. Usage: python run.py python run.py --data-dir /path/to/data --output-root generator/outputs """ import argparse import json import sys import warnings from pathlib import Path # Allow running from project root (python3 generator/run.py ...) or from inside generator/ _here = Path(__file__).resolve().parent if str(_here) not in sys.path: sys.path.insert(0, str(_here)) warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning) def parse_args(argv=None): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("config_path") parser.add_argument("--data-dir", default=None) parser.add_argument("--output-root", default="generator/outputs") parser.add_argument("--use-gpu", action="store_true", help="Accepted for pipeline compatibility (GPU auto-detected).") return parser.parse_args(argv) def main(config_path, *, data_dir_override=None, output_root="generator/outputs"): import torch from src.data import GeneratorDataset, get_transform from src.models import get_model from src.training import train_dcgan from src.utils import load_config cfg = load_config(config_path) run_name = cfg.get("run_name", Path(config_path).stem) device = "cuda" if torch.cuda.is_available() else "cpu" data_dir = data_dir_override or cfg.get("data_dir", "data") output_root = Path(output_root) models_dir = output_root / "models" logs_dir = output_root / "logs" print(f"Run: {run_name}") print(f"Config: {cfg}") print(f"Device: {device} Data: {data_dir}") model, kind = get_model(cfg) augment = cfg.get("augment", True) transform = get_transform(cfg.get("image_size", 128), augment=augment) dataset = GeneratorDataset( data_dir, sources=cfg.get("sources", ["wiki"]), subsample=cfg.get("subsample", 1.0), transform=transform, ) print(f"Dataset size: {len(dataset)}") if kind == "dcgan": generator, discriminator = model history = train_dcgan( generator, discriminator, dataset, cfg, save_dir=models_dir, run_name=run_name, device=device, ) else: raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase") logs_dir.mkdir(parents=True, exist_ok=True) out = logs_dir / f"{run_name}.json" with open(out, "w") as f: json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2) print(f"\nSaved log to {out}") if __name__ == "__main__": args = parse_args(sys.argv[1:]) main(args.config_path, data_dir_override=args.data_dir, output_root=args.output_root)