# NOTE: pasted file-browser metadata kept as a comment so the module stays importable.
# File: DRL_PROJ/generator/run.py — 2026-04-30 13:10:33 +01:00 — 107 lines, 3.6 KiB, Python
"""
Train a generative model from a config file.
Usage:
python run.py <config.json>
python run.py <config.json> --data-dir /path/to/data --output-root generator/outputs
"""
import argparse
import json
import sys
import warnings
from pathlib import Path
# Allow running from project root (python3 generator/run.py ...) or from inside generator/
_here = Path(__file__).resolve().parent
if str(_here) not in sys.path:
sys.path.insert(0, str(_here))
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
def parse_args(argv=None):
    """Parse CLI arguments.

    Args:
        argv: argument list to parse; None means use sys.argv[1:] (argparse default).
    Returns:
        argparse.Namespace with config_path, data_dir, output_root, use_gpu.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("config_path")
    ap.add_argument("--data-dir", default=None)
    ap.add_argument("--output-root", default="generator/outputs")
    ap.add_argument(
        "--use-gpu",
        action="store_true",
        help="Accepted for pipeline compatibility (GPU auto-detected).",
    )
    return ap.parse_args(argv)
def main(config_path, *, data_dir_override=None, output_root="generator/outputs"):
    """Train the generative model described by the JSON config at *config_path*.

    Args:
        config_path: path to the run's JSON config file.
        data_dir_override: if given, takes precedence over the config's "data_dir".
        output_root: directory receiving the models/ and logs/ subfolders.

    Side effects: trains the model (checkpoints under <output_root>/models) and
    writes a training log to <output_root>/logs/<run_name>.json.

    Raises:
        NotImplementedError: if the config's model kind has no trainer yet.
    """
    # Heavy dependencies are imported lazily so importing this module stays cheap.
    import torch
    from src.data import GeneratorDataset, get_transform
    from src.models import get_model
    from src.training import train_dcgan, train_wgan, train_vae, train_ddpm
    from src.utils import load_config

    config = load_config(config_path)
    run_name = config.get("run_name", Path(config_path).stem)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = data_dir_override or config.get("data_dir", "data")

    out_root = Path(output_root)
    models_dir = out_root / "models"
    logs_dir = out_root / "logs"

    print(f"Run: {run_name}")
    print(f"Config: {config}")
    print(f"Device: {device} Data: {data_dir}")

    model, kind = get_model(config)
    # GAN-family models arrive as a (generator, discriminator/critic) pair;
    # count trainable parameters of the first network in that case.
    net = model[0] if isinstance(model, tuple) else model
    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"Trainable params: {n_params:,}")

    transform = get_transform(config.get("image_size", 128), augment=config.get("augment", True))
    dataset = GeneratorDataset(
        data_dir,
        sources=config.get("sources", ["wiki"]),
        subsample=config.get("subsample", 1.0),
        transform=transform,
    )
    print(f"Dataset size: {len(dataset)}")

    # Dispatch to the kind-specific trainer; all share the same keyword contract.
    common = dict(save_dir=models_dir, run_name=run_name, device=device)
    if kind == "dcgan":
        history = train_dcgan(model[0], model[1], dataset, config, **common)
    elif kind == "wgan":
        history = train_wgan(model[0], model[1], dataset, config, **common)
    elif kind == "vae":
        history = train_vae(model, dataset, config, **common)
    elif kind == "ddpm":
        history = train_ddpm(model, dataset, config, **common)
    else:
        raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")

    logs_dir.mkdir(parents=True, exist_ok=True)
    out = logs_dir / f"{run_name}.json"
    with open(out, "w") as f:
        json.dump(
            {"run_name": run_name, "config": config, "history": history, "n_params": n_params},
            f,
            indent=2,
        )
    print(f"\nSaved log to {out}")
if __name__ == "__main__":
args = parse_args(sys.argv[1:])
main(args.config_path, data_dir_override=args.data_dir, output_root=args.output_root)