Testing VAE until it works - v1
This commit is contained in:
@@ -5,6 +5,6 @@
|
|||||||
"lr_d": 1e-4,
|
"lr_d": 1e-4,
|
||||||
"beta_kl": 0.25,
|
"beta_kl": 0.25,
|
||||||
"lambda_perceptual": 0.1,
|
"lambda_perceptual": 0.1,
|
||||||
"lambda_adversarial": 0.1,
|
"lambda_adversarial": 0.01,
|
||||||
"ndf_patch": 64
|
"ndf_patch": 64
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -52,7 +52,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
|
|||||||
|
|
||||||
# Count total trainable parameters
|
# Count total trainable parameters
|
||||||
if isinstance(model, tuple):
|
if isinstance(model, tuple):
|
||||||
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad)
|
n_params = sum(p.numel() for m in model for p in m.parameters() if p.requires_grad)
|
||||||
else:
|
else:
|
||||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
print(f"Trainable params: {n_params:,}")
|
print(f"Trainable params: {n_params:,}")
|
||||||
|
|||||||
@@ -393,7 +393,6 @@ def _save_vae_samples(
|
|||||||
# Interleave real / reconstruction pairs
|
# Interleave real / reconstruction pairs
|
||||||
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
|
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
|
||||||
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
||||||
vae.train()
|
|
||||||
|
|
||||||
|
|
||||||
def train_vae(
|
def train_vae(
|
||||||
@@ -549,6 +548,8 @@ def train_vae(
|
|||||||
# ── PatchGAN discriminator step ───────────────────────────────
|
# ── PatchGAN discriminator step ───────────────────────────────
|
||||||
adv_d = real.new_zeros(1).squeeze()
|
adv_d = real.new_zeros(1).squeeze()
|
||||||
if use_adversarial:
|
if use_adversarial:
|
||||||
|
# Warmup: only start adversarial training once epoch exceeds kl_warmup_epochs
|
||||||
|
if epoch > kl_warmup_epochs:
|
||||||
opt_d.zero_grad()
|
opt_d.zero_grad()
|
||||||
d_real = patchgan(real)
|
d_real = patchgan(real)
|
||||||
d_fake = patchgan(recon.detach())
|
d_fake = patchgan(recon.detach())
|
||||||
@@ -560,7 +561,7 @@ def train_vae(
|
|||||||
|
|
||||||
# ── PatchGAN generator adversarial loss ───────────────────────
|
# ── PatchGAN generator adversarial loss ───────────────────────
|
||||||
adv_g = real.new_zeros(1).squeeze()
|
adv_g = real.new_zeros(1).squeeze()
|
||||||
if use_adversarial:
|
if use_adversarial and epoch > kl_warmup_epochs:
|
||||||
adv_g = hinge_g_loss(patchgan(recon))
|
adv_g = hinge_g_loss(patchgan(recon))
|
||||||
vae_loss = vae_loss + lambda_adversarial * adv_g
|
vae_loss = vae_loss + lambda_adversarial * adv_g
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user