From 02b1f7f16f12c8e986dcab1425c76fd706b36608 Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sat, 2 May 2026 18:38:55 +0100 Subject: [PATCH] Testing VAE until it works - v1 --- .../configs/phase3/p3_3_vae_patchgan.json | 2 +- generator/run.py | 2 +- generator/src/training/trainer.py | 21 ++++++++++--------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/generator/configs/phase3/p3_3_vae_patchgan.json b/generator/configs/phase3/p3_3_vae_patchgan.json index 75d26ff..1043c47 100644 --- a/generator/configs/phase3/p3_3_vae_patchgan.json +++ b/generator/configs/phase3/p3_3_vae_patchgan.json @@ -5,6 +5,6 @@ "lr_d": 1e-4, "beta_kl": 0.25, "lambda_perceptual": 0.1, - "lambda_adversarial": 0.1, + "lambda_adversarial": 0.01, "ndf_patch": 64 } diff --git a/generator/run.py b/generator/run.py index f11727e..e9ed7a2 100644 --- a/generator/run.py +++ b/generator/run.py @@ -52,7 +52,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs" # Count total trainable parameters if isinstance(model, tuple): - n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad) + n_params = sum(p.numel() for m in model for p in m.parameters() if p.requires_grad) else: n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Trainable params: {n_params:,}") diff --git a/generator/src/training/trainer.py b/generator/src/training/trainer.py index ce08a6d..3f3b7f3 100644 --- a/generator/src/training/trainer.py +++ b/generator/src/training/trainer.py @@ -393,7 +393,6 @@ def _save_vae_samples( # Interleave real / reconstruction pairs pairs = torch.stack([real, recon], dim=1).flatten(0, 1) save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4) - vae.train() def train_vae( @@ -549,18 +548,20 @@ def train_vae( # ── PatchGAN discriminator step ─────────────────────────────── adv_d = real.new_zeros(1).squeeze() if use_adversarial: - opt_d.zero_grad() - d_real = patchgan(real) - d_fake = patchgan(recon.detach()) - adv_d = hinge_d_loss(d_real, d_fake) - if torch.isfinite(adv_d): - adv_d.backward() - torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip) - opt_d.step() + # Warmup: only start adversarial training after 20% of epochs + if epoch > kl_warmup_epochs: + opt_d.zero_grad() + d_real = patchgan(real) + d_fake = patchgan(recon.detach()) + adv_d = hinge_d_loss(d_real, d_fake) + if torch.isfinite(adv_d): + adv_d.backward() + torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip) + opt_d.step() # ── PatchGAN generator adversarial loss ─────────────────────── adv_g = real.new_zeros(1).squeeze() - if use_adversarial: + if use_adversarial and epoch > kl_warmup_epochs: adv_g = hinge_g_loss(patchgan(recon)) vae_loss = vae_loss + lambda_adversarial * adv_g