"""
Train a generative model from a config file.

Usage:
    python run.py <config.json>
    python run.py <config.json> --data-dir /path/to/data --output-root generator/outputs
"""
import argparse
import json
import sys
import warnings
from pathlib import Path

# Allow running from project root (python3 generator/run.py ...) or from inside generator/
# by making this file's own directory importable, so the `from src...` imports
# inside main() resolve regardless of the current working directory.
_here = Path(__file__).resolve().parent
if str(_here) not in sys.path:
    sys.path.insert(0, str(_here))

# Silence just the "Corrupt EXIF data" UserWarning — presumably raised while
# loading training images (PIL emits this message); harmless for training.
# NOTE(review): confirm the emitter if the filter ever stops matching.
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
def parse_args(argv=None):
    """Parse command-line arguments for the training launcher.

    Parameters
    ----------
    argv : list[str] | None
        Argument vector without the program name; ``None`` makes argparse
        fall back to ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        With attributes ``config_path``, ``data_dir``, ``output_root``,
        ``use_gpu``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring holds line-oriented usage examples; the default
        # formatter would reflow them into one paragraph, so keep them raw.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("config_path", help="Path to the JSON training config file.")
    parser.add_argument(
        "--data-dir",
        default=None,
        help="Override the config's data_dir entry (default: use config value).",
    )
    parser.add_argument(
        "--output-root",
        default="generator/outputs",
        help="Root directory that receives models/ and logs/ subdirectories.",
    )
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        help="Accepted for pipeline compatibility (GPU auto-detected).",
    )
    return parser.parse_args(argv)
def main(config_path, *, data_dir_override=None, output_root="generator/outputs"):
    """Train the generative model described by the config at *config_path*.

    Parameters
    ----------
    config_path : str | Path
        JSON config file understood by ``src.utils.load_config``.
    data_dir_override : str | None
        When given, takes precedence over the config's ``data_dir`` entry.
    output_root : str | Path
        Root directory; checkpoints go under ``models/``, the JSON run log
        under ``logs/``.

    Raises
    ------
    NotImplementedError
        If the config's model ``kind`` has no training routine yet.
    """
    # Heavy imports are deferred so importing this module (and `--help`)
    # stays fast and does not require torch to be installed.
    import torch

    from src.data import GeneratorDataset, get_transform
    from src.models import get_model
    from src.training import train_dcgan, train_wgan, train_vae, train_ddpm
    from src.utils import load_config

    cfg = load_config(config_path)

    run_name = cfg.get("run_name", Path(config_path).stem)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = data_dir_override or cfg.get("data_dir", "data")
    output_root = Path(output_root)
    models_dir = output_root / "models"
    logs_dir = output_root / "logs"

    print(f"Run: {run_name}")
    print(f"Config: {cfg}")
    print(f"Device: {device} Data: {data_dir}")

    model, kind = get_model(cfg)

    # Count total trainable parameters. Adversarial models come back as a
    # tuple (e.g. generator + discriminator/critic); the previous code counted
    # only model[0], under-reporting the "total" this comment promised, so sum
    # across every component.
    if isinstance(model, tuple):
        n_params = sum(
            p.numel() for m in model for p in m.parameters() if p.requires_grad
        )
    else:
        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {n_params:,}")

    augment = cfg.get("augment", True)
    transform = get_transform(cfg.get("image_size", 128), augment=augment)
    dataset = GeneratorDataset(
        data_dir,
        sources=cfg.get("sources", ["wiki"]),
        subsample=cfg.get("subsample", 1.0),
        transform=transform,
    )
    print(f"Dataset size: {len(dataset)}")

    # Dispatch to the training routine matching the model kind. Each returns
    # a history object that is serialized to the run log below.
    if kind == "dcgan":
        generator, discriminator = model
        history = train_dcgan(
            generator, discriminator, dataset, cfg,
            save_dir=models_dir, run_name=run_name, device=device,
        )
    elif kind == "wgan":
        generator, critic = model
        history = train_wgan(
            generator, critic, dataset, cfg,
            save_dir=models_dir, run_name=run_name, device=device,
        )
    elif kind == "vae":
        history = train_vae(
            model, dataset, cfg,
            save_dir=models_dir, run_name=run_name, device=device,
        )
    elif kind == "ddpm":
        history = train_ddpm(
            model, dataset, cfg,
            save_dir=models_dir, run_name=run_name, device=device,
        )
    else:
        raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")

    # Persist config + history + parameter count so runs are reproducible
    # and comparable after the fact.
    logs_dir.mkdir(parents=True, exist_ok=True)
    out = logs_dir / f"{run_name}.json"
    log_data = {"run_name": run_name, "config": cfg, "history": history}
    log_data["n_params"] = n_params
    with open(out, "w") as f:
        json.dump(log_data, f, indent=2)
    print(f"\nSaved log to {out}")
if __name__ == "__main__":
    # CLI entry point: translate parsed flags into keyword arguments for main().
    cli = parse_args(sys.argv[1:])
    main(
        cli.config_path,
        data_dir_override=cli.data_dir,
        output_root=cli.output_root,
    )