Testing VAE until it works - v1
This commit is contained in:
@@ -408,11 +408,12 @@ def train_vae(
|
||||
Config toggles:
|
||||
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
|
||||
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
|
||||
free_bits > 0 → per-dimension KL free bits (prevents posterior
|
||||
collapse and KL explosion)
|
||||
|
||||
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
|
||||
|
||||
KL is computed as mean over latent dimensions (scale-invariant), so
|
||||
beta_kl is comparable across different latent_dim values.
|
||||
|
||||
AMP is intentionally disabled for VAE training — mixed-precision float16
|
||||
overflows when the KL divergence spikes, producing NaN cascades that
|
||||
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
|
||||
@@ -432,7 +433,6 @@ def train_vae(
|
||||
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
|
||||
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
|
||||
lr_d = cfg.get("lr_d", 1e-4)
|
||||
free_bits_val = cfg.get("free_bits", 0.0)
|
||||
grad_clip = cfg.get("grad_clip", 1.0)
|
||||
ema_decay = cfg.get("ema_decay", 0.9999)
|
||||
sample_interval = cfg.get("sample_interval", 10)
|
||||
@@ -441,7 +441,6 @@ def train_vae(
|
||||
|
||||
use_perceptual = lambda_perceptual > 0
|
||||
use_adversarial = lambda_adversarial > 0
|
||||
use_free_bits = free_bits_val > 0
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True,
|
||||
@@ -509,7 +508,7 @@ def train_vae(
|
||||
print(
|
||||
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
|
||||
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
|
||||
f" λ_adv={lambda_adversarial} free_bits={free_bits_val}"
|
||||
f" λ_adv={lambda_adversarial}"
|
||||
)
|
||||
|
||||
t_start = time.time()
|
||||
@@ -532,14 +531,8 @@ def train_vae(
|
||||
recon, mu, log_var = vae(real)
|
||||
mse = F.mse_loss(recon, real)
|
||||
|
||||
# KL divergence with optional free bits
|
||||
kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()) # (B, latent_dim)
|
||||
if use_free_bits:
|
||||
# Free bits: ensure each dimension contributes at least free_bits_val KL.
|
||||
# Dimensions below the threshold are raised to it, preventing posterior
|
||||
# collapse (dimensions that go to 0) while still penalising large KL.
|
||||
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits_val)
|
||||
kl = kl_per_dim.sum(1).mean()
|
||||
# KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
|
||||
kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
|
||||
|
||||
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
|
||||
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
|
||||
|
||||
Reference in New Issue
Block a user