Testing VAE until it works - v1

This commit is contained in:
Johnny Fernandes
2026-05-02 13:11:56 +01:00
parent c7804d2984
commit ec8d4ae336
84 changed files with 9 additions and 1744 deletions
+6 -13
View File
@@ -408,11 +408,12 @@ def train_vae(
Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
free_bits > 0 → per-dimension KL free bits (prevents posterior
collapse and KL explosion)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
KL is computed as the mean over batch and latent dimensions (scale-invariant), so
beta_kl is comparable across different latent_dim values.
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
@@ -432,7 +433,6 @@ def train_vae(
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
lr_d = cfg.get("lr_d", 1e-4)
free_bits_val = cfg.get("free_bits", 0.0)
grad_clip = cfg.get("grad_clip", 1.0)
ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10)
@@ -441,7 +441,6 @@ def train_vae(
use_perceptual = lambda_perceptual > 0
use_adversarial = lambda_adversarial > 0
use_free_bits = free_bits_val > 0
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
@@ -509,7 +508,7 @@ def train_vae(
print(
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
f" λ_adv={lambda_adversarial} free_bits={free_bits_val}"
f" λ_adv={lambda_adversarial}"
)
t_start = time.time()
@@ -532,14 +531,8 @@ def train_vae(
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
# KL divergence with optional free bits
kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()) # (B, latent_dim)
if use_free_bits:
# Free bits: ensure each dimension contributes at least free_bits_val KL.
# Dimensions below the threshold are raised to it, preventing posterior
# collapse (dimensions that go to 0) while still penalising large KL.
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits_val)
kl = kl_per_dim.sum(1).mean()
# KL divergence: mean over batch and latent dims (scale-invariant w.r.t. latent_dim)
kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc