"""Training loops for the generative-model experiments.

Contains one entry point per model family:

* ``train_dcgan`` - vanilla DCGAN with BCE loss (Phase 1 baseline)
* ``train_wgan``  - WGAN-GP (Phase 2)
* ``train_vae``   - VAE with optional perceptual / PatchGAN losses (Phase 3)
* ``train_ddpm``  - DDPM with DDIM sampling for evaluation (Phase 4)

Every loop keeps an EMA copy of the trained network, periodically writes
sample grids, periodically computes FID on EMA samples, and returns a
``history`` dict with per-epoch metrics.
"""

import os
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

from src.training.ema import EMA
from src.training.fid import FIDEvaluator

if hasattr(torch.amp, "GradScaler"):
    _GradScaler = torch.amp.GradScaler
    _autocast = torch.amp.autocast
else:
    from torch.cuda.amp import GradScaler as _GS, autocast as _AC

    def _GradScaler(device="", enabled=True, **kw):
        """Legacy shim: the old scaler takes no device argument.

        ``enabled`` must be forwarded explicitly; otherwise a caller asking
        for a disabled scaler would silently get an enabled one.
        """
        return _GS(enabled=enabled, **kw)

    def _autocast(device_type="", enabled=True, **kw):
        """Legacy shim: the old CUDA autocast takes no device_type argument."""
        return _AC(enabled=enabled, **kw)


# ────────────────────────────────────────────────────────────────────────────
# Shared private helpers
# ────────────────────────────────────────────────────────────────────────────

def _resolve_device(device) -> torch.device:
    """Return the requested device, falling back to CPU without CUDA."""
    return torch.device(device if torch.cuda.is_available() else "cpu")


def _count_trainable(module) -> int:
    """Number of parameters in *module* that require gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def _num_workers(cfg: dict) -> int:
    """DataLoader worker count: ``cfg['num_workers']`` or min(4, CPU count)."""
    return cfg.get("num_workers", min(4, os.cpu_count() or 1))


def _make_loader(train_dataset, cfg: dict, batch_size: int, device: torch.device) -> DataLoader:
    """Build the shuffled, drop-last training loader shared by all loops."""
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=_num_workers(cfg),
        pin_memory=(device.type == "cuda"),
        drop_last=True,
    )


def _linear_decay(epochs: int):
    """Return an ``lr_lambda``: constant until ``epochs // 2``, then linear to 0.

    The decay span is clamped to at least 1 so that a degenerate ``epochs``
    value cannot raise ZeroDivisionError when the scheduler is constructed.
    """
    decay_start = epochs // 2
    span = max(epochs - decay_start, 1)

    def lr_lambda(ep: int) -> float:
        return max(0.0, 1.0 - max(ep - decay_start, 0) / span)

    return lr_lambda


def _sample_generator(model, n_samples: int, latent_dim: int, device, chunk: int = 64) -> torch.Tensor:
    """Draw ``n_samples`` images from a GAN generator in chunks of ``chunk``.

    Noise has shape ``(chunk, latent_dim, 1, 1)``. One extra chunk is always
    generated and the result is truncated to exactly ``n_samples``.
    """
    with torch.no_grad():
        batches = [
            model(torch.randn(chunk, latent_dim, 1, 1, device=device))
            for _ in range(n_samples // chunk + 1)
        ]
        return torch.cat(batches)[:n_samples]


def _save_samples(generator_ema, samples_dir: Path, epoch: int, *,
                  fixed_noise: torch.Tensor, device) -> None:
    """Write a 4-column grid of EMA-generator samples for ``fixed_noise``.

    Args:
        generator_ema: EMA wrapper; its ``.model`` attribute is called.
        samples_dir: Output directory (created if missing).
        epoch: Epoch number used in the file name ``epoch_XXXX.png``.
        fixed_noise: Latents fed to the generator.
        device: Device the latents are moved to before the forward pass.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    # Match the FID path, which also evaluates the EMA model in eval mode.
    generator_ema.model.eval()
    with torch.no_grad():
        imgs = generator_ema.model(fixed_noise.to(device))  # EMA model, [-1, 1]
    imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0  # -> [0, 1]
    save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)


# ────────────────────────────────────────────────────────────────────────────
# Phase 1 — DCGAN
# ────────────────────────────────────────────────────────────────────────────

def train_dcgan(
    generator,
    discriminator,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).

    Used as the Phase 1 baseline for cheap pipeline ablations.
    No gradient penalty, no n_critic, single G/D step per batch.

    Args:
        generator: Network mapping ``(B, latent_dim, 1, 1)`` noise to images.
        discriminator: Network returning logits; its output is compared by
            ``BCEWithLogitsLoss`` against labels of shape ``(B,)``.
        train_dataset: Dataset whose batches are image tensors (each batch is
            moved to the device directly, so items must not be tuples).
        cfg: Config dict. Required: ``epochs``, ``batch_size``. Optional:
            ``lr_g``, ``lr_d``, ``beta1``, ``beta2``, ``latent_dim``,
            ``ema_decay``, ``sample_interval``, ``fid_interval``,
            ``fid_n_real``, ``num_workers``.
        save_dir: Checkpoint directory; samples go to
            ``save_dir.parent / "samples" / run_name``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; CPU is used when CUDA is unavailable.

    Returns:
        History dict with per-epoch ``g_loss``, ``d_loss``, ``d_real``,
        ``d_fake`` lists, a ``fid`` dict keyed by epoch, and ``train_time_s``.
    """
    device = _resolve_device(device)
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    n_g = _count_trainable(generator)
    n_d = _count_trainable(discriminator)
    print(f"Generator: {n_g:,} params Discriminator: {n_d:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 2e-4)
    lr_d = cfg.get("lr_d", 2e-4)
    beta1 = cfg.get("beta1", 0.5)
    beta2 = cfg.get("beta2", 0.999)
    latent_dim = cfg.get("latent_dim", 100)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    loader = _make_loader(train_dataset, cfg, batch_size, device)

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
    bce = nn.BCEWithLogitsLoss()

    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)
    scaler_d = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    # Fixed noise for consistent sample tracking across epochs
    fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=_num_workers(cfg))

    history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")

    print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")

    # Linear LR decay from epoch epochs//2 to epochs
    sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=_linear_decay(epochs))
    sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, lr_lambda=_linear_decay(epochs))

    t_start = time.time()
    for epoch in range(1, epochs + 1):
        generator.train()
        discriminator.train()
        g_sum = d_sum = real_sum = fake_sum = 0.0
        n_batches = 0

        for imgs in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            imgs = imgs.to(device)
            bsz = imgs.size(0)
            real_labels = torch.ones(bsz, device=device)
            fake_labels = torch.zeros(bsz, device=device)

            # ── Discriminator step ────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                # detach: the D step must not backprop into the generator
                fake = generator(noise).detach()
                d_real = discriminator(imgs)
                d_fake = discriminator(fake)
                d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)
            opt_d.zero_grad()
            scaler_d.scale(d_loss).backward()
            scaler_d.step(opt_d)
            scaler_d.update()

            # ── Generator step ────────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                g_loss = bce(discriminator(generator(noise)), real_labels)
            opt_g.zero_grad()
            scaler_g.scale(g_loss).backward()
            scaler_g.step(opt_g)
            scaler_g.update()
            ema.update(generator)

            g_sum += g_loss.item()
            d_sum += d_loss.item()
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_batches += 1

        # Guard against an empty epoch (dataset smaller than one batch with
        # drop_last=True), matching the WGAN and VAE loops.
        denom = max(n_batches, 1)
        avg_g = g_sum / denom
        avg_d = d_sum / denom
        avg_r = real_sum / denom
        avg_f = fake_sum / denom
        history["g_loss"].append(avg_g)
        history["d_loss"].append(avg_d)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} D: {avg_d:.4f} D(real): {avg_r:.4f} D(fake): {avg_f:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)

        if epoch % fid_interval == 0:
            ema.model.eval()
            fake_imgs = _sample_generator(ema.model, fid_n_real, latent_dim, device)
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_g.step()
        sched_d.step()

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")

    history["train_time_s"] = time.time() - t_start
    return history


# ────────────────────────────────────────────────────────────────────────────
# Phase 2 — WGAN-GP
# ────────────────────────────────────────────────────────────────────────────

def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
    """Two-sided gradient penalty (Gulrajani et al., 2017).

    Evaluates the critic on random per-sample interpolations between ``real``
    and ``fake`` and penalises the squared deviation of the input-gradient
    L2 norm from 1.

    Args:
        critic: Critic network.
        real: Real batch; indexed as 4-D (the norm is taken over dims 1-3).
        fake: Generated batch, same shape as ``real``.
        device: Device on which the interpolation coefficients are drawn.

    Returns:
        Scalar penalty, differentiable w.r.t. the critic parameters.
    """
    bsz = real.size(0)
    eps = torch.rand(bsz, 1, 1, 1, device=device)
    interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    d_interp = critic(interp)
    grad = torch.autograd.grad(
        outputs=d_interp,
        inputs=interp,
        grad_outputs=torch.ones_like(d_interp),
        # create_graph so the penalty itself can be backpropagated
        create_graph=True,
        retain_graph=True,
    )[0]
    return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()


def train_wgan(
    generator,
    critic,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """WGAN-GP training loop (Gulrajani et al., 2017).

    Used for Phase 2.2–2.4. Gradient penalty replaces weight clipping.
    The critic runs in float32 to keep GP gradient computation numerically
    stable; AMP is used only for the generator forward/backward.

    Args:
        generator: Network mapping ``(B, latent_dim, 1, 1)`` noise to images.
        critic: Critic network returning unbounded scores.
        train_dataset: Dataset whose batches are image tensors.
        cfg: Config dict. Required: ``epochs``, ``batch_size``. Optional:
            ``lr_g``, ``lr_d``, ``beta1``, ``beta2``, ``latent_dim``,
            ``n_critic``, ``gp_lambda``, ``ema_decay``, ``sample_interval``,
            ``fid_interval``, ``fid_n_real``, ``num_workers``.
        save_dir: Checkpoint directory; samples go to
            ``save_dir.parent / "samples" / run_name``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; CPU is used when CUDA is unavailable.

    Returns:
        History dict with per-epoch ``g_loss``, ``w_dist``, ``gp``,
        ``d_real``, ``d_fake`` lists, a ``fid`` dict keyed by epoch, and
        ``train_time_s``.
    """
    device = _resolve_device(device)
    generator = generator.to(device)
    critic = critic.to(device)

    n_g = _count_trainable(generator)
    n_c = _count_trainable(critic)
    print(f"Generator: {n_g:,} params Critic: {n_c:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 1e-4)
    lr_d = cfg.get("lr_d", 1e-4)
    beta1 = cfg.get("beta1", 0.0)
    beta2 = cfg.get("beta2", 0.9)
    latent_dim = cfg.get("latent_dim", 128)
    n_critic = cfg.get("n_critic", 5)
    gp_lambda = cfg.get("gp_lambda", 10)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    loader = _make_loader(train_dataset, cfg, batch_size, device)

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_c = torch.optim.Adam(critic.parameters(), lr=lr_d, betas=(beta1, beta2))

    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    # Fixed noise for consistent sample tracking across epochs
    fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=_num_workers(cfg))

    history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")

    print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")

    # Linear LR decay from epoch epochs//2 to epochs
    sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=_linear_decay(epochs))
    sched_c = torch.optim.lr_scheduler.LambdaLR(opt_c, lr_lambda=_linear_decay(epochs))

    t_start = time.time()
    for epoch in range(1, epochs + 1):
        generator.train()
        critic.train()
        g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
        n_c_steps = n_g_steps = 0

        for batch_idx, real in enumerate(tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False)):
            real = real.to(device)
            bsz = real.size(0)

            # ── Critic step (every batch) ─────────────────────────────────
            # Run the critic in float32: the gradient penalty needs
            # second-order gradients (create_graph=True) and AMP can
            # degrade stability here.
            opt_c.zero_grad()
            with torch.no_grad():
                fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
            real_f32 = real.float()
            fake_f32 = fake.float().detach()

            d_real = critic(real_f32)
            d_fake = critic(fake_f32)
            gp = _gradient_penalty(critic, real_f32, fake_f32, device)
            c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
            c_loss.backward()
            opt_c.step()

            w_dist = (d_real.mean() - d_fake.mean()).item()
            w_sum += w_dist
            gp_sum += gp.item()
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_c_steps += 1

            # ── Generator step (every n_critic batches) ───────────────────
            if (batch_idx + 1) % n_critic == 0:
                opt_g.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
                # Critic is evaluated outside autocast so that it really
                # runs in float32, as the docstring promises.  Gradients
                # accumulated on the critic here are discarded by the
                # opt_c.zero_grad() at the start of the next critic step.
                g_loss = -critic(fake.float()).mean()
                scaler_g.scale(g_loss).backward()
                scaler_g.step(opt_g)
                scaler_g.update()
                ema.update(generator)
                g_sum += g_loss.item()
                n_g_steps += 1

        avg_w = w_sum / max(n_c_steps, 1)
        avg_gp = gp_sum / max(n_c_steps, 1)
        avg_g = g_sum / max(n_g_steps, 1)
        avg_r = real_sum / max(n_c_steps, 1)
        avg_f = fake_sum / max(n_c_steps, 1)
        history["g_loss"].append(avg_g)
        history["w_dist"].append(avg_w)
        history["gp"].append(avg_gp)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} W-dist: {avg_w:.4f} GP: {avg_gp:.4f} "
            f"C(real): {avg_r:.4f} C(fake): {avg_f:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)

        if epoch % fid_interval == 0:
            ema.model.eval()
            fake_imgs = _sample_generator(ema.model, fid_n_real, latent_dim, device)
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_g.step()
        sched_c.step()

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(critic.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")

    history["train_time_s"] = time.time() - t_start
    return history


# ────────────────────────────────────────────────────────────────────────────
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
# ────────────────────────────────────────────────────────────────────────────

def _save_vae_samples(
    vae,
    samples_dir: Path,
    epoch: int,
    *,
    fixed_z: torch.Tensor,
    fixed_real: torch.Tensor,
    device,
) -> None:
    """Save prior samples and a real-vs-reconstruction grid side by side.

    Writes ``epoch_XXXX.png`` (decoded ``fixed_z``) and
    ``epoch_XXXX_recon.png`` (interleaved real / reconstruction pairs).
    The model is evaluated in eval mode and its previous train/eval mode is
    restored afterwards, even if saving fails.

    Args:
        vae: Model exposing ``decode(z)`` and a forward pass returning a
            3-tuple whose first element is the reconstruction.
        samples_dir: Output directory (created if missing).
        epoch: Epoch number used in the file names.
        fixed_z: Latents decoded for the prior grid.
        fixed_real: Real images to reconstruct (assumed in [-1, 1], like the
            decoder output - TODO confirm against the dataset transform).
        device: Device inputs are moved to.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            prior = vae.decode(fixed_z.to(device))
            prior = (prior.clamp(-1, 1) + 1.0) / 2.0
            save_image(prior, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)

            recon, _, _ = vae(fixed_real.to(device))
            recon = (recon.clamp(-1, 1) + 1.0) / 2.0
            real = (fixed_real.to(device) + 1.0) / 2.0
            # Interleave real / reconstruction pairs
            pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
            save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        vae.train(was_training)


def train_vae(
    vae,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """VAE training loop covering Phase 3.1 – 3.3 and Phase 5.

    Config toggles:
        lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
        lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)

    Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl

    KL is computed as mean over latent dimensions (scale-invariant), so
    beta_kl is comparable across different latent_dim values. β_kl is ramped
    linearly from 0 to its target over the first 20% of epochs.

    AMP is intentionally disabled for VAE training — mixed-precision float16
    overflows when the KL divergence spikes, producing NaN cascades that
    corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
    computation runs in float32.

    Args:
        vae: Model whose forward returns ``(recon, mu, log_var)`` and which
            exposes ``decode`` and ``sample(n, device)``.
        train_dataset: Dataset whose batches are image tensors.
        cfg: Config dict. Required: ``epochs``, ``batch_size``. Optional:
            ``lr``, ``latent_dim``, ``beta_kl``, ``lambda_perceptual``,
            ``lambda_adversarial``, ``lr_d``, ``grad_clip``, ``ema_decay``,
            ``sample_interval``, ``fid_interval``, ``fid_n_real``,
            ``ndf_patch``, ``image_size``, ``num_workers``.
        save_dir: Checkpoint directory; samples go to
            ``save_dir.parent / "samples" / run_name``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; CPU is used when CUDA is unavailable.

    Returns:
        History dict with per-epoch ``recon_loss``, ``kl_loss``,
        ``perc_loss``, ``adv_g_loss``, ``adv_d_loss`` lists, a ``fid`` dict
        keyed by epoch, and ``train_time_s``.
    """
    device = _resolve_device(device)
    vae = vae.to(device)

    n_vae = _count_trainable(vae)
    print(f"VAE: {n_vae:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 1e-3)
    latent_dim = cfg.get("latent_dim", 256)
    beta_kl = cfg.get("beta_kl", 1.0)
    lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
    lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
    lr_d = cfg.get("lr_d", 1e-4)
    grad_clip = cfg.get("grad_clip", 1.0)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    use_perceptual = lambda_perceptual > 0
    use_adversarial = lambda_adversarial > 0

    loader = _make_loader(train_dataset, cfg, batch_size, device)

    # AMP disabled — float16 overflows on KL spikes, causing NaN cascades
    opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)

    # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
    kl_warmup_epochs = max(1, epochs // 5)

    # Linear LR decay from epoch epochs//2 to epochs
    sched_vae = torch.optim.lr_scheduler.LambdaLR(opt_vae, lr_lambda=_linear_decay(epochs))
    sched_d = None  # set below if adversarial

    # ── Optional components ───────────────────────────────────────────────
    perc_fn = None
    patchgan = None
    opt_d = None

    if use_perceptual:
        from src.training.perceptual import PerceptualLoss
        perc_fn = PerceptualLoss().to(device).float()
        print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")

    if use_adversarial:
        from src.models.patchgan import PatchGANDiscriminator, hinge_d_loss, hinge_g_loss
        patchgan = PatchGANDiscriminator(
            ndf=cfg.get("ndf_patch", 64),
            image_size=cfg.get("image_size", 64),
        ).to(device).float()
        opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
        sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, lr_lambda=_linear_decay(epochs))
        n_d = sum(p.numel() for p in patchgan.parameters())
        print(f"PatchGAN: {n_d:,} params")
    else:
        hinge_d_loss = hinge_g_loss = None  # never called

    # ── Fixed seeds for consistent visualisation ──────────────────────────
    fixed_z = torch.randn(16, latent_dim, device=device)
    # Grab first 16 real images from the loader for reconstruction tracking.
    # The iterator is deliberately a temporary: keeping a reference to it
    # would hold its worker processes and prefetched batches alive for the
    # entire training run.
    fixed_real = next(iter(loader))[:16].cpu()

    ema = EMA(vae, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=_num_workers(cfg))

    history = {
        "recon_loss": [],
        "kl_loss": [],
        "perc_loss": [],
        "adv_g_loss": [],
        "adv_d_loss": [],
        "fid": {},
    }
    best_fid = float("inf")
    nan_skipped = 0

    print(
        f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
        f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
        f" λ_adv={lambda_adversarial}"
    )

    t_start = time.time()
    for epoch in range(1, epochs + 1):
        vae.train()
        if patchgan is not None:
            patchgan.train()

        # KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs.
        # Depends only on the epoch, so it is computed once per epoch.
        current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)

        recon_sum = kl_sum = perc_sum = adv_g_sum = adv_d_sum = 0.0
        n_batches = 0

        for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            real = real.to(device).float()

            # ── VAE forward (float32, no AMP) ────────────────────────────
            recon, mu, log_var = vae(real)
            mse = F.mse_loss(recon, real)
            # KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
            kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
            perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
            vae_loss = mse + current_beta * kl + lambda_perceptual * perc

            # ── NaN/Inf guard ────────────────────────────────────────────
            # Skipped batches are counted but excluded from epoch averages.
            if not torch.isfinite(vae_loss):
                nan_skipped += 1
                opt_vae.zero_grad()
                continue

            # ── PatchGAN discriminator step ───────────────────────────────
            adv_d = real.new_zeros(1).squeeze()
            if use_adversarial:
                opt_d.zero_grad()
                d_real = patchgan(real)
                d_fake = patchgan(recon.detach())
                adv_d = hinge_d_loss(d_real, d_fake)
                if torch.isfinite(adv_d):
                    adv_d.backward()
                    torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
                    opt_d.step()

            # ── PatchGAN generator adversarial loss ───────────────────────
            adv_g = real.new_zeros(1).squeeze()
            if use_adversarial:
                adv_g = hinge_g_loss(patchgan(recon))
                vae_loss = vae_loss + lambda_adversarial * adv_g

            # ── VAE backward ──────────────────────────────────────────────
            opt_vae.zero_grad()
            vae_loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
            opt_vae.step()
            ema.update(vae)

            recon_sum += mse.item()
            kl_sum += kl.item()
            perc_sum += perc.item()
            adv_g_sum += adv_g.item()
            adv_d_sum += adv_d.item()
            n_batches += 1

        denom = max(n_batches, 1)
        avg_r = recon_sum / denom
        avg_k = kl_sum / denom
        avg_p = perc_sum / denom
        avg_g = adv_g_sum / denom
        avg_d = adv_d_sum / denom
        history["recon_loss"].append(avg_r)
        history["kl_loss"].append(avg_k)
        history["perc_loss"].append(avg_p)
        history["adv_g_loss"].append(avg_g)
        history["adv_d_loss"].append(avg_d)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
            f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
            f" (NaN skipped: {nan_skipped})"
        )

        if epoch % sample_interval == 0:
            _save_vae_samples(
                ema.model,
                samples_dir,
                epoch,
                fixed_z=fixed_z,
                fixed_real=fixed_real,
                device=device,
            )

        if epoch % fid_interval == 0:
            ema.model.eval()
            with torch.no_grad():
                fake_imgs = torch.cat([
                    ema.model.sample(64, device)
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_vae.step()
        if sched_d is not None:
            sched_d.step()

    torch.save(vae.state_dict(), save_dir / f"{run_name}_final_vae.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    if patchgan is not None:
        torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")

    history["train_time_s"] = time.time() - t_start
    print(f"Total NaN-skipped batches: {nan_skipped}")
    return history


# ────────────────────────────────────────────────────────────────────────────
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
# ────────────────────────────────────────────────────────────────────────────

def train_ddpm(
    model,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """DDPM training loop (Ho et al., 2020) covering Phase 4.1 – 4.4.

    Config keys:
        noise_schedule — "linear" (4.1) or "cosine" (4.2+)
        pred_type — "eps" (4.1–4.2) or "v" (4.3+)
        T — diffusion timesteps (default 1000)
        base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py)
        ddim_steps — DDIM steps for FID evaluation (default 100)

    Also read from ``cfg``: ``epochs`` and ``batch_size`` (required), ``lr``,
    ``image_size``, ``ema_decay``, ``sample_interval``, ``fid_interval``,
    ``fid_n_real``, ``num_workers``.

    Args:
        model: Denoising network passed to ``diffusion_loss`` / ``ddim_sample``.
        train_dataset: Dataset whose batches are image tensors.
        cfg: Config dict (see above).
        save_dir: Checkpoint directory; samples go to
            ``save_dir.parent / "samples" / run_name``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; CPU is used when CUDA is unavailable.

    Returns:
        History dict with a per-epoch ``loss`` list, a ``fid`` dict keyed by
        epoch, and ``train_time_s``.
    """
    from src.training.diffusion import (
        linear_betas,
        cosine_betas,
        make_alpha_bars,
        diffusion_loss,
        ddim_sample,
    )

    device = _resolve_device(device)
    model = model.to(device)

    n_params = _count_trainable(model)
    print(f"U-Net: {n_params:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 2e-4)
    T = cfg.get("T", 1000)
    noise_schedule = cfg.get("noise_schedule", "linear")
    pred_type = cfg.get("pred_type", "eps")
    ddim_steps = cfg.get("ddim_steps", 100)
    image_size = cfg.get("image_size", 64)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    # Build noise schedule and register on device
    betas = (cosine_betas(T) if noise_schedule == "cosine" else linear_betas(T)).to(device)
    alpha_bars = make_alpha_bars(betas)  # on device

    loader = _make_loader(train_dataset, cfg, batch_size, device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    use_amp = device.type == "cuda"
    scaler = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(model, decay=ema_decay)

    # NOTE(review): this tensor is not passed to ddim_sample below (whose
    # signature is not visible here), so sample grids currently start from
    # fresh noise each time.  The draw is kept so the global RNG stream, and
    # hence seeded runs, stay unchanged.
    fixed_noise = torch.randn(16, 3, image_size, image_size, device=device)  # noqa: F841

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=_num_workers(cfg))

    history = {"loss": [], "fid": {}}
    best_fid = float("inf")

    print(
        f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
        f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
    )

    # Linear LR decay from epoch epochs//2 to epochs
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=_linear_decay(epochs))

    t_start = time.time()
    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        n_batches = 0

        for x0 in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            x0 = x0.to(device)
            t = torch.randint(0, T, (x0.size(0),), device=device)

            with _autocast("cuda", enabled=use_amp):
                loss = diffusion_loss(model, x0, t, alpha_bars, pred_type)

            opt.zero_grad()
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true grads
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            ema.update(model)

            loss_sum += loss.item()
            n_batches += 1

        # Guard against an empty epoch, matching the WGAN and VAE loops.
        avg_loss = loss_sum / max(n_batches, 1)
        history["loss"].append(avg_loss)
        print(f"[{epoch:03d}/{epochs}] Loss: {avg_loss:.5f}")

        if epoch % sample_interval == 0:
            samples_dir.mkdir(parents=True, exist_ok=True)
            ema.model.eval()
            with torch.no_grad():
                # Quick visualisation: 16 DDIM samples from the EMA model
                imgs = ddim_sample(
                    ema.model,
                    16,
                    image_size,
                    alpha_bars,
                    n_steps=50,
                    pred_type=pred_type,
                    device=str(device),
                    batch_size=16,
                )
            imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0
            save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)

        if epoch % fid_interval == 0:
            ema.model.eval()
            # no_grad for consistency with the visualisation path: sampling
            # thousands of images must never build an autograd graph.
            with torch.no_grad():
                fake_imgs = ddim_sample(
                    ema.model,
                    fid_n_real,
                    image_size,
                    alpha_bars,
                    n_steps=ddim_steps,
                    pred_type=pred_type,
                    device=str(device),
                    batch_size=32,
                )
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched.step()

    torch.save(model.state_dict(), save_dir / f"{run_name}_final_unet.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")

    history["train_time_s"] = time.time() - t_start
    return history