Testing VAE until it works - v1

This commit is contained in:
Johnny Fernandes
2026-05-02 18:38:55 +01:00
parent 6c1e939803
commit 02b1f7f16f
3 changed files with 13 additions and 12 deletions
@@ -5,6 +5,6 @@
"lr_d": 1e-4, "lr_d": 1e-4,
"beta_kl": 0.25, "beta_kl": 0.25,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.1, "lambda_adversarial": 0.01,
"ndf_patch": 64 "ndf_patch": 64
} }
+1 -1
View File
@@ -52,7 +52,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
# Count total trainable parameters # Count total trainable parameters
if isinstance(model, tuple): if isinstance(model, tuple):
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad) n_params = sum(p.numel() for m in model for p in m.parameters() if p.requires_grad)
else: else:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_params:,}") print(f"Trainable params: {n_params:,}")
+11 -10
View File
@@ -393,7 +393,6 @@ def _save_vae_samples(
# Interleave real / reconstruction pairs # Interleave real / reconstruction pairs
pairs = torch.stack([real, recon], dim=1).flatten(0, 1) pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4) save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
vae.train()
def train_vae( def train_vae(
@@ -549,18 +548,20 @@ def train_vae(
# ── PatchGAN discriminator step ─────────────────────────────── # ── PatchGAN discriminator step ───────────────────────────────
adv_d = real.new_zeros(1).squeeze() adv_d = real.new_zeros(1).squeeze()
if use_adversarial: if use_adversarial:
opt_d.zero_grad() # Warmup: only start adversarial training after 20% of epochs
d_real = patchgan(real) if epoch > kl_warmup_epochs:
d_fake = patchgan(recon.detach()) opt_d.zero_grad()
adv_d = hinge_d_loss(d_real, d_fake) d_real = patchgan(real)
if torch.isfinite(adv_d): d_fake = patchgan(recon.detach())
adv_d.backward() adv_d = hinge_d_loss(d_real, d_fake)
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip) if torch.isfinite(adv_d):
opt_d.step() adv_d.backward()
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
opt_d.step()
# ── PatchGAN generator adversarial loss ─────────────────────── # ── PatchGAN generator adversarial loss ───────────────────────
adv_g = real.new_zeros(1).squeeze() adv_g = real.new_zeros(1).squeeze()
if use_adversarial: if use_adversarial and epoch > kl_warmup_epochs:
adv_g = hinge_g_loss(patchgan(recon)) adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g vae_loss = vae_loss + lambda_adversarial * adv_g