Testing VAE until it works - v1
This commit is contained in:
@@ -393,7 +393,6 @@ def _save_vae_samples(
|
||||
# Interleave real / reconstruction pairs
|
||||
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
|
||||
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
||||
vae.train()
|
||||
|
||||
|
||||
def train_vae(
|
||||
@@ -549,18 +548,20 @@ def train_vae(
|
||||
# ── PatchGAN discriminator step ───────────────────────────────
|
||||
adv_d = real.new_zeros(1).squeeze()
|
||||
if use_adversarial:
|
||||
opt_d.zero_grad()
|
||||
d_real = patchgan(real)
|
||||
d_fake = patchgan(recon.detach())
|
||||
adv_d = hinge_d_loss(d_real, d_fake)
|
||||
if torch.isfinite(adv_d):
|
||||
adv_d.backward()
|
||||
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
|
||||
opt_d.step()
|
||||
# Warmup: only start adversarial training after 20% of epochs
|
||||
if epoch > kl_warmup_epochs:
|
||||
opt_d.zero_grad()
|
||||
d_real = patchgan(real)
|
||||
d_fake = patchgan(recon.detach())
|
||||
adv_d = hinge_d_loss(d_real, d_fake)
|
||||
if torch.isfinite(adv_d):
|
||||
adv_d.backward()
|
||||
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
|
||||
opt_d.step()
|
||||
|
||||
# ── PatchGAN generator adversarial loss ───────────────────────
|
||||
adv_g = real.new_zeros(1).squeeze()
|
||||
if use_adversarial:
|
||||
if use_adversarial and epoch > kl_warmup_epochs:
|
||||
adv_g = hinge_g_loss(patchgan(recon))
|
||||
vae_loss = vae_loss + lambda_adversarial * adv_g
|
||||
|
||||
|
||||
Reference in New Issue
Block a user