From f89d7dcfda9d11f26ab97b21cd1789ec61b1141f Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sat, 2 May 2026 00:32:45 +0100 Subject: [PATCH] Trying a few different VAE settings --- generator/configs/phase3/_base_phase3.json | 2 + generator/configs/phase3/p3_1_vae.json | 4 +- .../configs/phase3/p3_2_vae_perceptual.json | 4 +- .../configs/phase3/p3_3_vae_patchgan.json | 4 +- generator/src/models/vae.py | 3 +- generator/src/training/trainer.py | 97 ++++++++++++------- 6 files changed, 71 insertions(+), 43 deletions(-) diff --git a/generator/configs/phase3/_base_phase3.json b/generator/configs/phase3/_base_phase3.json index 3c3e438..084b8b4 100644 --- a/generator/configs/phase3/_base_phase3.json +++ b/generator/configs/phase3/_base_phase3.json @@ -7,6 +7,8 @@ "model": "vae", "latent_dim": 256, "ngf": 64, + "free_bits": 0.1, + "grad_clip": 1.0, "sample_interval": 10, "fid_interval": 25, "fid_n_real": 5000 diff --git a/generator/configs/phase3/p3_1_vae.json b/generator/configs/phase3/p3_1_vae.json index 75b6af9..0951564 100644 --- a/generator/configs/phase3/p3_1_vae.json +++ b/generator/configs/phase3/p3_1_vae.json @@ -1,8 +1,8 @@ { "extends": "_base_phase3.json", "run_name": "p3_1_vae", - "lr": 1e-3, - "beta_kl": 1.0, + "lr": 5e-4, + "beta_kl": 0.005, "lambda_perceptual": 0.0, "lambda_adversarial": 0.0 } diff --git a/generator/configs/phase3/p3_2_vae_perceptual.json b/generator/configs/phase3/p3_2_vae_perceptual.json index 8116ea2..93173e8 100644 --- a/generator/configs/phase3/p3_2_vae_perceptual.json +++ b/generator/configs/phase3/p3_2_vae_perceptual.json @@ -1,8 +1,8 @@ { "extends": "_base_phase3.json", "run_name": "p3_2_vae_perceptual", - "lr": 1e-3, - "beta_kl": 0.0001, + "lr": 5e-4, + "beta_kl": 0.005, "lambda_perceptual": 0.1, "lambda_adversarial": 0.0 } diff --git a/generator/configs/phase3/p3_3_vae_patchgan.json b/generator/configs/phase3/p3_3_vae_patchgan.json index 1c6e443..8fa823c 100644 --- a/generator/configs/phase3/p3_3_vae_patchgan.json +++ b/generator/configs/phase3/p3_3_vae_patchgan.json @@ -1,9 +1,9 @@ { "extends": "_base_phase3.json", "run_name": "p3_3_vae_patchgan", - "lr": 1e-3, + "lr": 5e-4, "lr_d": 1e-4, - "beta_kl": 0.0001, + "beta_kl": 0.005, "lambda_perceptual": 0.1, "lambda_adversarial": 0.1, "ndf_patch": 64 diff --git a/generator/src/models/vae.py b/generator/src/models/vae.py index 285ad56..55ffa86 100644 --- a/generator/src/models/vae.py +++ b/generator/src/models/vae.py @@ -98,7 +98,8 @@ class VAE(nn.Module): def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: h = self.encoder(x).flatten(1) - return self.fc_mu(h), self.fc_lv(h) + log_var = self.fc_lv(h).clamp(-10.0, 10.0) + return self.fc_mu(h), log_var def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: std = torch.exp(0.5 * log_var) diff --git a/generator/src/training/trainer.py b/generator/src/training/trainer.py index 4061525..8ea87e8 100644 --- a/generator/src/training/trainer.py +++ b/generator/src/training/trainer.py @@ -403,13 +403,20 @@ def train_vae( run_name: str, device: str = "cuda", ) -> dict: - """VAE training loop covering Phase 3.1 – 3.3. + """VAE training loop covering Phase 3.1 – 3.3 and Phase 5. Config toggles: lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+) lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3) + free_bits > 0 → per-dimension KL free bits (prevents posterior + collapse and KL explosion) Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl + + AMP is intentionally disabled for VAE training — mixed-precision float16 + overflows when the KL divergence spikes, producing NaN cascades that + corrupt the model irrecoverably. All VAE + perceptual + PatchGAN + computation runs in float32. """ device = torch.device(device if torch.cuda.is_available() else "cpu") vae = vae.to(device) @@ -425,6 +432,8 @@ def train_vae( lambda_perceptual = cfg.get("lambda_perceptual", 0.0) lambda_adversarial = cfg.get("lambda_adversarial", 0.0) lr_d = cfg.get("lr_d", 1e-4) + free_bits_val = cfg.get("free_bits", 0.0) + grad_clip = cfg.get("grad_clip", 1.0) ema_decay = cfg.get("ema_decay", 0.9999) sample_interval = cfg.get("sample_interval", 10) fid_interval = cfg.get("fid_interval", 25) @@ -432,6 +441,7 @@ def train_vae( use_perceptual = lambda_perceptual > 0 use_adversarial = lambda_adversarial > 0 + use_free_bits = free_bits_val > 0 loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, @@ -440,8 +450,8 @@ def train_vae( ) opt_vae = torch.optim.Adam(vae.parameters(), lr=lr) - use_amp = device.type == "cuda" - scaler = _GradScaler("cuda", enabled=use_amp) + # AMP disabled — float16 overflows on KL spikes, causing NaN cascades + use_amp = False # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training kl_warmup_epochs = max(1, epochs // 5) @@ -456,11 +466,10 @@ def train_vae( perc_fn = None patchgan = None opt_d = None - scaler_d = None if use_perceptual: from src.training.perceptual import PerceptualLoss - perc_fn = PerceptualLoss().to(device) + perc_fn = PerceptualLoss().to(device).float() print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3") if use_adversarial: @@ -468,15 +477,14 @@ def train_vae( patchgan = PatchGANDiscriminator( ndf=cfg.get("ndf_patch", 64), image_size=cfg.get("image_size", 64), - ).to(device) + ).to(device).float() opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999)) - scaler_d = _GradScaler("cuda", enabled=use_amp) sched_d = torch.optim.lr_scheduler.LambdaLR( opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))) n_d = sum(p.numel() for p in patchgan.parameters()) print(f"PatchGAN: {n_d:,} params") else: - hinge_d_loss = hinge_g_loss = None # satisfy linter, never called + hinge_d_loss = hinge_g_loss = None # never called # ── Fixed seeds for consistent visualisation ────────────────────────── fixed_z = torch.randn(16, latent_dim, device=device) @@ -497,9 +505,11 @@ def train_vae( "adv_g_loss": [], "adv_d_loss": [], "fid": {}, } best_fid = float("inf") + nan_skipped = 0 print( - f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}" - f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}" + f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}" + f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}" + f" λ_adv={lambda_adversarial} free_bits={free_bits_val}" ) t_start = time.time() @@ -513,43 +523,56 @@ def train_vae( n_batches = 0 for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False): - real = real.to(device) + real = real.to(device).float() # KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs) - # ── VAE forward ─────────────────────────────────────────────── - with _autocast("cuda", enabled=use_amp): - recon, mu, log_var = vae(real) - mse = F.mse_loss(recon, real) - kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean() - perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze() - vae_loss = mse + current_beta * kl + lambda_perceptual * perc + # ── VAE forward (float32, no AMP) ──────────────────────────── + recon, mu, log_var = vae(real) + mse = F.mse_loss(recon, real) + + # KL divergence with optional free bits + kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()) # (B, latent_dim) + if use_free_bits: + # Free bits: ensure each dimension contributes at least free_bits_val KL. + # Dimensions below the threshold are raised to it, preventing posterior + # collapse (dimensions that go to 0) while still penalising large KL. + kl_per_dim = torch.clamp(kl_per_dim, min=free_bits_val) + kl = kl_per_dim.sum(1).mean() + + perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze() + vae_loss = mse + current_beta * kl + lambda_perceptual * perc + + # ── NaN/Inf guard ──────────────────────────────────────────── + if not torch.isfinite(vae_loss): + nan_skipped += 1 + opt_vae.zero_grad() + continue # ── PatchGAN discriminator step ─────────────────────────────── adv_d = real.new_zeros(1).squeeze() if use_adversarial: opt_d.zero_grad() - with _autocast("cuda", enabled=use_amp): - d_real = patchgan(real) - d_fake = patchgan(recon.detach()) - adv_d = hinge_d_loss(d_real, d_fake) - scaler_d.scale(adv_d).backward() - scaler_d.step(opt_d) - scaler_d.update() + d_real = patchgan(real) + d_fake = patchgan(recon.detach()) + adv_d = hinge_d_loss(d_real, d_fake) + if torch.isfinite(adv_d): + adv_d.backward() + torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip) + opt_d.step() # ── PatchGAN generator adversarial loss ─────────────────────── adv_g = real.new_zeros(1).squeeze() if use_adversarial: - with _autocast("cuda", enabled=use_amp): - adv_g = hinge_g_loss(patchgan(recon)) - vae_loss = vae_loss + lambda_adversarial * adv_g + adv_g = hinge_g_loss(patchgan(recon)) + vae_loss = vae_loss + lambda_adversarial * adv_g # ── VAE backward ────────────────────────────────────────────── opt_vae.zero_grad() - scaler.scale(vae_loss).backward() - scaler.step(opt_vae) - scaler.update() + vae_loss.backward() + torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip) + opt_vae.step() ema.update(vae) recon_sum += mse.item() @@ -559,11 +582,11 @@ def train_vae( adv_d_sum += adv_d.item() n_batches += 1 - avg_r = recon_sum / n_batches - avg_k = kl_sum / n_batches - avg_p = perc_sum / n_batches - avg_g = adv_g_sum / n_batches - avg_d = adv_d_sum / n_batches + avg_r = recon_sum / max(n_batches, 1) + avg_k = kl_sum / max(n_batches, 1) + avg_p = perc_sum / max(n_batches, 1) + avg_g = adv_g_sum / max(n_batches, 1) + avg_d = adv_d_sum / max(n_batches, 1) history["recon_loss"].append(avg_r) history["kl_loss"].append(avg_k) history["perc_loss"].append(avg_p) @@ -574,6 +597,7 @@ def train_vae( f"[{epoch:03d}/{epochs}] " f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} " f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}" + f" (NaN skipped: {nan_skipped})" ) if epoch % sample_interval == 0: @@ -607,6 +631,7 @@ def train_vae( if patchgan is not None: torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt") history["train_time_s"] = time.time() - t_start + print(f"Total NaN-skipped batches: {nan_skipped}") return history