VAE refactor
This commit is contained in:
@@ -403,13 +403,20 @@ def train_vae(
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""VAE training loop covering Phase 3.1 – 3.3.
|
||||
"""VAE training loop covering Phase 3.1 – 3.3 and Phase 5.
|
||||
|
||||
Config toggles:
|
||||
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
|
||||
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
|
||||
free_bits > 0 → per-dimension KL free bits (prevents posterior
|
||||
collapse and KL explosion)
|
||||
|
||||
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
|
||||
|
||||
AMP is intentionally disabled for VAE training — mixed-precision float16
|
||||
overflows when the KL divergence spikes, producing NaN cascades that
|
||||
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
|
||||
computation runs in float32.
|
||||
"""
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
vae = vae.to(device)
|
||||
@@ -425,6 +432,8 @@ def train_vae(
|
||||
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
|
||||
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
|
||||
lr_d = cfg.get("lr_d", 1e-4)
|
||||
free_bits_val = cfg.get("free_bits", 0.0)
|
||||
grad_clip = cfg.get("grad_clip", 1.0)
|
||||
ema_decay = cfg.get("ema_decay", 0.9999)
|
||||
sample_interval = cfg.get("sample_interval", 10)
|
||||
fid_interval = cfg.get("fid_interval", 25)
|
||||
@@ -432,6 +441,7 @@ def train_vae(
|
||||
|
||||
use_perceptual = lambda_perceptual > 0
|
||||
use_adversarial = lambda_adversarial > 0
|
||||
use_free_bits = free_bits_val > 0
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True,
|
||||
@@ -440,8 +450,8 @@ def train_vae(
|
||||
)
|
||||
|
||||
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
|
||||
use_amp = device.type == "cuda"
|
||||
scaler = _GradScaler("cuda", enabled=use_amp)
|
||||
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
|
||||
use_amp = False
|
||||
|
||||
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
|
||||
kl_warmup_epochs = max(1, epochs // 5)
|
||||
@@ -456,11 +466,10 @@ def train_vae(
|
||||
perc_fn = None
|
||||
patchgan = None
|
||||
opt_d = None
|
||||
scaler_d = None
|
||||
|
||||
if use_perceptual:
|
||||
from src.training.perceptual import PerceptualLoss
|
||||
perc_fn = PerceptualLoss().to(device)
|
||||
perc_fn = PerceptualLoss().to(device).float()
|
||||
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
|
||||
|
||||
if use_adversarial:
|
||||
@@ -468,15 +477,14 @@ def train_vae(
|
||||
patchgan = PatchGANDiscriminator(
|
||||
ndf=cfg.get("ndf_patch", 64),
|
||||
image_size=cfg.get("image_size", 64),
|
||||
).to(device)
|
||||
).to(device).float()
|
||||
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
|
||||
scaler_d = _GradScaler("cuda", enabled=use_amp)
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||
n_d = sum(p.numel() for p in patchgan.parameters())
|
||||
print(f"PatchGAN: {n_d:,} params")
|
||||
else:
|
||||
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called
|
||||
hinge_d_loss = hinge_g_loss = None # never called
|
||||
|
||||
# ── Fixed seeds for consistent visualisation ──────────────────────────
|
||||
fixed_z = torch.randn(16, latent_dim, device=device)
|
||||
@@ -497,9 +505,11 @@ def train_vae(
|
||||
"adv_g_loss": [], "adv_d_loss": [], "fid": {},
|
||||
}
|
||||
best_fid = float("inf")
|
||||
nan_skipped = 0
|
||||
print(
|
||||
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
|
||||
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
|
||||
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
|
||||
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
|
||||
f" λ_adv={lambda_adversarial} free_bits={free_bits_val}"
|
||||
)
|
||||
|
||||
t_start = time.time()
|
||||
@@ -513,43 +523,56 @@ def train_vae(
|
||||
n_batches = 0
|
||||
|
||||
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
|
||||
real = real.to(device)
|
||||
real = real.to(device).float()
|
||||
|
||||
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
|
||||
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
|
||||
|
||||
# ── VAE forward ───────────────────────────────────────────────
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
recon, mu, log_var = vae(real)
|
||||
mse = F.mse_loss(recon, real)
|
||||
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
|
||||
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
|
||||
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
|
||||
# ── VAE forward (float32, no AMP) ────────────────────────────
|
||||
recon, mu, log_var = vae(real)
|
||||
mse = F.mse_loss(recon, real)
|
||||
|
||||
# KL divergence with optional free bits
|
||||
kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()) # (B, latent_dim)
|
||||
if use_free_bits:
|
||||
# Free bits: ensure each dimension contributes at least free_bits_val KL.
|
||||
# Dimensions below the threshold are raised to it, preventing posterior
|
||||
# collapse (dimensions that go to 0) while still penalising large KL.
|
||||
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits_val)
|
||||
kl = kl_per_dim.sum(1).mean()
|
||||
|
||||
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
|
||||
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
|
||||
|
||||
# ── NaN/Inf guard ────────────────────────────────────────────
|
||||
if not torch.isfinite(vae_loss):
|
||||
nan_skipped += 1
|
||||
opt_vae.zero_grad()
|
||||
continue
|
||||
|
||||
# ── PatchGAN discriminator step ───────────────────────────────
|
||||
adv_d = real.new_zeros(1).squeeze()
|
||||
if use_adversarial:
|
||||
opt_d.zero_grad()
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
d_real = patchgan(real)
|
||||
d_fake = patchgan(recon.detach())
|
||||
adv_d = hinge_d_loss(d_real, d_fake)
|
||||
scaler_d.scale(adv_d).backward()
|
||||
scaler_d.step(opt_d)
|
||||
scaler_d.update()
|
||||
d_real = patchgan(real)
|
||||
d_fake = patchgan(recon.detach())
|
||||
adv_d = hinge_d_loss(d_real, d_fake)
|
||||
if torch.isfinite(adv_d):
|
||||
adv_d.backward()
|
||||
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
|
||||
opt_d.step()
|
||||
|
||||
# ── PatchGAN generator adversarial loss ───────────────────────
|
||||
adv_g = real.new_zeros(1).squeeze()
|
||||
if use_adversarial:
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
adv_g = hinge_g_loss(patchgan(recon))
|
||||
vae_loss = vae_loss + lambda_adversarial * adv_g
|
||||
adv_g = hinge_g_loss(patchgan(recon))
|
||||
vae_loss = vae_loss + lambda_adversarial * adv_g
|
||||
|
||||
# ── VAE backward ──────────────────────────────────────────────
|
||||
opt_vae.zero_grad()
|
||||
scaler.scale(vae_loss).backward()
|
||||
scaler.step(opt_vae)
|
||||
scaler.update()
|
||||
vae_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
|
||||
opt_vae.step()
|
||||
ema.update(vae)
|
||||
|
||||
recon_sum += mse.item()
|
||||
@@ -559,11 +582,11 @@ def train_vae(
|
||||
adv_d_sum += adv_d.item()
|
||||
n_batches += 1
|
||||
|
||||
avg_r = recon_sum / n_batches
|
||||
avg_k = kl_sum / n_batches
|
||||
avg_p = perc_sum / n_batches
|
||||
avg_g = adv_g_sum / n_batches
|
||||
avg_d = adv_d_sum / n_batches
|
||||
avg_r = recon_sum / max(n_batches, 1)
|
||||
avg_k = kl_sum / max(n_batches, 1)
|
||||
avg_p = perc_sum / max(n_batches, 1)
|
||||
avg_g = adv_g_sum / max(n_batches, 1)
|
||||
avg_d = adv_d_sum / max(n_batches, 1)
|
||||
history["recon_loss"].append(avg_r)
|
||||
history["kl_loss"].append(avg_k)
|
||||
history["perc_loss"].append(avg_p)
|
||||
@@ -574,6 +597,7 @@ def train_vae(
|
||||
f"[{epoch:03d}/{epochs}] "
|
||||
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
|
||||
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
|
||||
f" (NaN skipped: {nan_skipped})"
|
||||
)
|
||||
|
||||
if epoch % sample_interval == 0:
|
||||
@@ -607,6 +631,7 @@ def train_vae(
|
||||
if patchgan is not None:
|
||||
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
|
||||
history["train_time_s"] = time.time() - t_start
|
||||
print(f"Total NaN-skipped batches: {nan_skipped}")
|
||||
return history
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user