VAE fix w/ new results

This commit is contained in:
Johnny Fernandes
2026-05-02 00:32:45 +01:00
parent 1a7f67ab9c
commit bac52bc15e
90 changed files with 1197 additions and 1106 deletions
+11 -4
View File
@@ -22,17 +22,21 @@ def _init_weights(m):
nn.init.normal_(m.weight, 0.0, 0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
def _norm(channels: int) -> nn.GroupNorm:
return nn.GroupNorm(8, channels)
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
_norm(out_ch),
nn.ReLU(inplace=True),
)
@@ -65,7 +69,7 @@ class VAE(nn.Module):
for _ in range(n_down - 1):
enc_layers += [
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ch * 2),
_norm(ch * 2),
nn.LeakyReLU(0.2, inplace=True),
]
ch *= 2
@@ -98,9 +102,12 @@ class VAE(nn.Module):
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
h = self.encoder(x).flatten(1)
return self.fc_mu(h), self.fc_lv(h)
log_var = self.fc_lv(h).clamp(-10.0, 10.0)
return self.fc_mu(h), log_var
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
if not self.training:
return mu
std = torch.exp(0.5 * log_var)
return mu + std * torch.randn_like(std)
+2
View File
@@ -20,3 +20,5 @@ class EMA:
def update(self, model: nn.Module) -> None:
for p_ema, p in zip(self.model.parameters(), model.parameters()):
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
for b_ema, b in zip(self.model.buffers(), model.buffers()):
b_ema.copy_(b)
+3 -2
View File
@@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance
class FIDEvaluator:
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
num_workers: int = 2):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.n_real = n_real
# Cache real images as a CPU tensor ([-1, 1] range)
imgs_list = []
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
num_workers=4, drop_last=False)
num_workers=num_workers, drop_last=False)
for batch in loader:
imgs_list.append(batch.cpu())
if sum(x.size(0) for x in imgs_list) >= n_real:
+67 -44
View File
@@ -66,7 +66,7 @@ def train_dcgan(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -86,7 +86,8 @@ def train_dcgan(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf")
@@ -239,7 +240,7 @@ def train_wgan(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -257,7 +258,8 @@ def train_wgan(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf")
@@ -391,7 +393,6 @@ def _save_vae_samples(
# Interleave real / reconstruction pairs
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
vae.train()
def train_vae(
@@ -403,13 +404,21 @@ def train_vae(
run_name: str,
device: str = "cuda",
) -> dict:
"""VAE training loop covering Phase 3.1 3.3.
"""VAE training loop covering Phase 3.1 3.3 and Phase 5.
Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
KL is computed as mean over latent dimensions (scale-invariant), so
beta_kl is comparable across different latent_dim values.
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
computation runs in float32.
"""
device = torch.device(device if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
@@ -425,6 +434,7 @@ def train_vae(
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
lr_d = cfg.get("lr_d", 1e-4)
grad_clip = cfg.get("grad_clip", 1.0)
ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10)
fid_interval = cfg.get("fid_interval", 25)
@@ -435,13 +445,13 @@ def train_vae(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
use_amp = device.type == "cuda"
scaler = _GradScaler("cuda", enabled=use_amp)
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
use_amp = False
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
kl_warmup_epochs = max(1, epochs // 5)
@@ -456,11 +466,10 @@ def train_vae(
perc_fn = None
patchgan = None
opt_d = None
scaler_d = None
if use_perceptual:
from src.training.perceptual import PerceptualLoss
perc_fn = PerceptualLoss().to(device)
perc_fn = PerceptualLoss().to(device).float()
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
if use_adversarial:
@@ -468,15 +477,14 @@ def train_vae(
patchgan = PatchGANDiscriminator(
ndf=cfg.get("ndf_patch", 64),
image_size=cfg.get("image_size", 64),
).to(device)
).to(device).float()
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
scaler_d = _GradScaler("cuda", enabled=use_amp)
sched_d = torch.optim.lr_scheduler.LambdaLR(
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
n_d = sum(p.numel() for p in patchgan.parameters())
print(f"PatchGAN: {n_d:,} params")
else:
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called
hinge_d_loss = hinge_g_loss = None # never called
# ── Fixed seeds for consistent visualisation ──────────────────────────
fixed_z = torch.randn(16, latent_dim, device=device)
@@ -490,16 +498,19 @@ def train_vae(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {
"recon_loss": [], "kl_loss": [], "perc_loss": [],
"adv_g_loss": [], "adv_d_loss": [], "fid": {},
}
best_fid = float("inf")
nan_skipped = 0
print(
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
f" λ_adv={lambda_adversarial}"
)
t_start = time.time()
@@ -513,43 +524,52 @@ def train_vae(
n_batches = 0
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
real = real.to(device)
real = real.to(device).float()
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
# ── VAE forward ───────────────────────────────────────────────
with _autocast("cuda", enabled=use_amp):
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── VAE forward (float32, no AMP) ────────────────────────────
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
# KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── NaN/Inf guard ────────────────────────────────────────────
if not torch.isfinite(vae_loss):
nan_skipped += 1
opt_vae.zero_grad()
continue
# ── PatchGAN discriminator step ───────────────────────────────
adv_d = real.new_zeros(1).squeeze()
if use_adversarial:
opt_d.zero_grad()
with _autocast("cuda", enabled=use_amp):
# Warmup: only start adversarial training after 20% of epochs
if epoch > kl_warmup_epochs:
opt_d.zero_grad()
d_real = patchgan(real)
d_fake = patchgan(recon.detach())
adv_d = hinge_d_loss(d_real, d_fake)
scaler_d.scale(adv_d).backward()
scaler_d.step(opt_d)
scaler_d.update()
if torch.isfinite(adv_d):
adv_d.backward()
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
opt_d.step()
# ── PatchGAN generator adversarial loss ───────────────────────
adv_g = real.new_zeros(1).squeeze()
if use_adversarial:
with _autocast("cuda", enabled=use_amp):
adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g
if use_adversarial and epoch > kl_warmup_epochs:
adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g
# ── VAE backward ──────────────────────────────────────────────
opt_vae.zero_grad()
scaler.scale(vae_loss).backward()
scaler.step(opt_vae)
scaler.update()
vae_loss.backward()
torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
opt_vae.step()
ema.update(vae)
recon_sum += mse.item()
@@ -559,11 +579,11 @@ def train_vae(
adv_d_sum += adv_d.item()
n_batches += 1
avg_r = recon_sum / n_batches
avg_k = kl_sum / n_batches
avg_p = perc_sum / n_batches
avg_g = adv_g_sum / n_batches
avg_d = adv_d_sum / n_batches
avg_r = recon_sum / max(n_batches, 1)
avg_k = kl_sum / max(n_batches, 1)
avg_p = perc_sum / max(n_batches, 1)
avg_g = adv_g_sum / max(n_batches, 1)
avg_d = adv_d_sum / max(n_batches, 1)
history["recon_loss"].append(avg_r)
history["kl_loss"].append(avg_k)
history["perc_loss"].append(avg_p)
@@ -574,6 +594,7 @@ def train_vae(
f"[{epoch:03d}/{epochs}] "
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
f" (NaN skipped: {nan_skipped})"
)
if epoch % sample_interval == 0:
@@ -607,6 +628,7 @@ def train_vae(
if patchgan is not None:
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
history["train_time_s"] = time.time() - t_start
print(f"Total NaN-skipped batches: {nan_skipped}")
return history
@@ -662,7 +684,7 @@ def train_ddpm(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
@@ -679,7 +701,8 @@ def train_ddpm(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"loss": [], "fid": {}}
best_fid = float("inf")