Preview of phase 2-5 implementation; needs a full check

This commit is contained in:
Johnny Fernandes
2026-04-30 13:10:33 +01:00
parent 6e32001ebc
commit 7417267117
35 changed files with 3605 additions and 115 deletions
+27 -2
View File
@@ -32,7 +32,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
import torch
from src.data import GeneratorDataset, get_transform
from src.models import get_model
from src.training import train_dcgan
from src.training import train_dcgan, train_wgan, train_vae, train_ddpm
from src.utils import load_config
cfg = load_config(config_path)
@@ -50,6 +50,13 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
model, kind = get_model(cfg)
# Count total trainable parameters
if isinstance(model, tuple):
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad)
else:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_params:,}")
augment = cfg.get("augment", True)
transform = get_transform(cfg.get("image_size", 128), augment=augment)
dataset = GeneratorDataset(
@@ -66,13 +73,31 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
generator, discriminator, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
elif kind == "wgan":
generator, critic = model
history = train_wgan(
generator, critic, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
elif kind == "vae":
history = train_vae(
model, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
elif kind == "ddpm":
history = train_ddpm(
model, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
else:
raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")
logs_dir.mkdir(parents=True, exist_ok=True)
out = logs_dir / f"{run_name}.json"
log_data = {"run_name": run_name, "config": cfg, "history": history}
log_data["n_params"] = n_params
with open(out, "w") as f:
json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2)
json.dump(log_data, f, indent=2)
print(f"\nSaved log to {out}")