Files
DRL_PROJ/generator/src/training/trainer.py
T
2026-05-02 13:26:39 +01:00

781 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
from src.training.ema import EMA
from src.training.fid import FIDEvaluator
if hasattr(torch.amp, "GradScaler"):
_GradScaler = torch.amp.GradScaler
_autocast = torch.amp.autocast
else:
from torch.cuda.amp import GradScaler as _GS, autocast as _AC
_GradScaler = lambda device="", enabled=True, **kw: _GS(**kw)
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
    """Render ``fixed_noise`` through the EMA generator and save a PNG grid.

    Args:
        generator_ema: EMA wrapper; its ``.model`` attribute is the generator
            that is called on the noise.
        samples_dir: Output directory (created if missing).
        epoch: Epoch number, used in the file name ``epoch_XXXX.png``.
        fixed_noise: Latent batch fed to the generator.
        device: Device the noise is moved to before the forward pass.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    model = generator_ema.model
    # Sample in eval mode (as the FID code paths do) and put the module back
    # into whatever mode it was in, so this helper has no lasting side effect.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            imgs = model(fixed_noise.to(device))  # expected range [-1, 1]
    finally:
        model.train(was_training)
    imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0  # -> [0, 1]
    save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
def train_dcgan(
    generator,
    discriminator,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).

    Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
    penalty, no n_critic, single G/D step per batch.

    Args:
        generator: Module mapping ``(B, latent_dim, 1, 1)`` noise to images.
        discriminator: Module scoring images; its output is fed to
            ``BCEWithLogitsLoss`` against ``(B,)`` labels, so it is presumably
            a ``(B,)`` tensor of logits -- confirm against the model.
        train_dataset: Dataset whose batches are plain image tensors.
        cfg: Hyper-parameters. Required: ``epochs``, ``batch_size``. Optional
            (default): ``lr_g`` (2e-4), ``lr_d`` (2e-4), ``beta1`` (0.5),
            ``beta2`` (0.999), ``latent_dim`` (100), ``ema_decay`` (0.9999),
            ``sample_interval`` (10), ``fid_interval`` (25),
            ``fid_n_real`` (5000), ``num_workers``.
        save_dir: Directory for checkpoints; sample grids go to
            ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with per-epoch lists ``g_loss``, ``d_loss``, ``d_real``,
        ``d_fake``, a ``fid`` dict keyed by epoch, and ``train_time_s``.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params Discriminator: {n_d:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 2e-4)
    lr_d = cfg.get("lr_d", 2e-4)
    beta1 = cfg.get("beta1", 0.5)
    beta2 = cfg.get("beta2", 0.999)
    latent_dim = cfg.get("latent_dim", 100)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)
    num_workers = cfg.get("num_workers", min(4, os.cpu_count() or 1))

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"), drop_last=True,
    )

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
    bce = nn.BCEWithLogitsLoss()

    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)
    scaler_d = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    # Fixed noise for consistent sample tracking across epochs
    fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=num_workers)

    history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")

    print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")

    # Linear LR decay from epoch epochs//2 to epochs. The span is clamped to
    # at least 1 so that degenerate epoch counts cannot divide by zero when
    # LambdaLR evaluates the factor at construction time.
    decay_start = epochs // 2
    decay_span = max(epochs - decay_start, 1)

    def lr_factor(ep):
        return max(0.0, 1.0 - max(ep - decay_start, 0) / decay_span)

    sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=lr_factor)
    sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, lr_lambda=lr_factor)

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        generator.train()
        discriminator.train()

        g_sum = d_sum = real_sum = fake_sum = 0.0
        n_batches = 0

        for imgs in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            imgs = imgs.to(device)
            bsz = imgs.size(0)
            real_labels = torch.ones(bsz, device=device)
            fake_labels = torch.zeros(bsz, device=device)

            # -- Discriminator step ------------------------------------
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                # Detach so the discriminator loss does not reach the generator.
                fake = generator(noise).detach()
                d_real = discriminator(imgs)
                d_fake = discriminator(fake)
                d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)

            opt_d.zero_grad()
            scaler_d.scale(d_loss).backward()
            scaler_d.step(opt_d)
            scaler_d.update()

            # -- Generator step ----------------------------------------
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                # Non-saturating loss: label generated images as real.
                g_loss = bce(discriminator(generator(noise)), real_labels)

            opt_g.zero_grad()
            scaler_g.scale(g_loss).backward()
            scaler_g.step(opt_g)
            scaler_g.update()

            ema.update(generator)

            g_sum += g_loss.item()
            d_sum += d_loss.item()
            # Mean raw discriminator outputs (logits, given BCEWithLogitsLoss).
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_batches += 1

        # max(..., 1) keeps an epoch without any batch (dataset smaller than
        # batch_size with drop_last=True) from raising ZeroDivisionError.
        denom = max(n_batches, 1)
        avg_g = g_sum / denom
        avg_d = d_sum / denom
        avg_r = real_sum / denom
        avg_f = fake_sum / denom

        history["g_loss"].append(avg_g)
        history["d_loss"].append(avg_d)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} D: {avg_d:.4f} D(real): {avg_r:.4f} D(fake): {avg_f:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)

        if epoch % fid_interval == 0:
            ema.model.eval()
            with torch.no_grad():
                # Generate in chunks of 64, then trim to exactly fid_n_real.
                fake_imgs = torch.cat([
                    ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_g.step()
        sched_d.step()

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")

    history["train_time_s"] = time.time() - t_start
    return history
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
"""Two-sided gradient penalty (Gulrajani et al., 2017)."""
bsz = real.size(0)
eps = torch.rand(bsz, 1, 1, 1, device=device)
interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
d_interp = critic(interp)
grad = torch.autograd.grad(
outputs=d_interp,
inputs=interp,
grad_outputs=torch.ones_like(d_interp),
create_graph=True,
retain_graph=True,
)[0]
return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()
def train_wgan(
    generator,
    critic,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """WGAN-GP training loop (Gulrajani et al., 2017).

    Used for Phase 2.2-2.4. Gradient penalty replaces weight clipping.
    The critic runs in float32 to keep GP gradient computation numerically
    stable; AMP is used only for the generator forward/backward.

    Args:
        generator: Module mapping ``(B, latent_dim, 1, 1)`` noise to images.
        critic: Module scoring images (unbounded scores, no sigmoid applied
            here).
        train_dataset: Dataset whose batches are plain image tensors.
        cfg: Hyper-parameters. Required: ``epochs``, ``batch_size``. Optional
            (default): ``lr_g`` (1e-4), ``lr_d`` (1e-4), ``beta1`` (0.0),
            ``beta2`` (0.9), ``latent_dim`` (128), ``n_critic`` (5),
            ``gp_lambda`` (10), ``ema_decay`` (0.9999),
            ``sample_interval`` (10), ``fid_interval`` (25),
            ``fid_n_real`` (5000), ``num_workers``.
        save_dir: Directory for checkpoints; sample grids go to
            ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with per-epoch lists ``g_loss``, ``w_dist``, ``gp``,
        ``d_real``, ``d_fake``, a ``fid`` dict keyed by epoch, and
        ``train_time_s``.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    critic = critic.to(device)

    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_c = sum(p.numel() for p in critic.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params Critic: {n_c:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 1e-4)
    lr_d = cfg.get("lr_d", 1e-4)
    beta1 = cfg.get("beta1", 0.0)
    beta2 = cfg.get("beta2", 0.9)
    latent_dim = cfg.get("latent_dim", 128)
    n_critic = cfg.get("n_critic", 5)
    gp_lambda = cfg.get("gp_lambda", 10)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)
    num_workers = cfg.get("num_workers", min(4, os.cpu_count() or 1))

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"), drop_last=True,
    )

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_c = torch.optim.Adam(critic.parameters(), lr=lr_d, betas=(beta1, beta2))

    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    # Fixed noise for consistent sample tracking across epochs
    fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=num_workers)

    history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")

    print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")

    # Linear LR decay from epoch epochs//2 to epochs. The span is clamped to
    # at least 1 so that degenerate epoch counts cannot divide by zero when
    # LambdaLR evaluates the factor at construction time.
    decay_start = epochs // 2
    decay_span = max(epochs - decay_start, 1)

    def lr_factor(ep):
        return max(0.0, 1.0 - max(ep - decay_start, 0) / decay_span)

    sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=lr_factor)
    sched_c = torch.optim.lr_scheduler.LambdaLR(opt_c, lr_lambda=lr_factor)

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        generator.train()
        critic.train()

        g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
        n_c_steps = n_g_steps = 0

        for batch_idx, real in enumerate(tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False)):
            real = real.to(device)
            bsz = real.size(0)

            # -- Critic step (every batch) -----------------------------
            # Run the critic in float32: the gradient penalty needs a double
            # backward (gradient of a gradient) and AMP can degrade stability
            # here.
            opt_c.zero_grad()
            with torch.no_grad():
                # No generator graph is needed for the critic update.
                fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))

            real_f32 = real.float()
            fake_f32 = fake.float()

            d_real = critic(real_f32)
            d_fake = critic(fake_f32)
            real_mean = d_real.mean()
            fake_mean = d_fake.mean()
            gp = _gradient_penalty(critic, real_f32, fake_f32, device)
            c_loss = fake_mean - real_mean + gp_lambda * gp

            c_loss.backward()
            opt_c.step()

            # Wasserstein distance estimate: E[C(real)] - E[C(fake)].
            w_sum += (real_mean - fake_mean).item()
            gp_sum += gp.item()
            real_sum += real_mean.item()
            fake_sum += fake_mean.item()
            n_c_steps += 1

            # -- Generator step (every n_critic batches) ---------------
            if (batch_idx + 1) % n_critic == 0:
                opt_g.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
                # The critic is evaluated outside autocast so that it really
                # runs in float32; inside the context the cast below would be
                # undone for autocast-eligible ops.
                g_loss = -critic(fake.float()).mean()

                scaler_g.scale(g_loss).backward()
                scaler_g.step(opt_g)
                scaler_g.update()
                ema.update(generator)

                g_sum += g_loss.item()
                n_g_steps += 1

        avg_w = w_sum / max(n_c_steps, 1)
        avg_gp = gp_sum / max(n_c_steps, 1)
        avg_g = g_sum / max(n_g_steps, 1)
        avg_r = real_sum / max(n_c_steps, 1)
        avg_f = fake_sum / max(n_c_steps, 1)

        history["g_loss"].append(avg_g)
        history["w_dist"].append(avg_w)
        history["gp"].append(avg_gp)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} W-dist: {avg_w:.4f} GP: {avg_gp:.4f} "
            f"C(real): {avg_r:.4f} C(fake): {avg_f:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)

        if epoch % fid_interval == 0:
            ema.model.eval()
            with torch.no_grad():
                # Generate in chunks of 64, then trim to exactly fid_n_real.
                fake_imgs = torch.cat([
                    ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_g.step()
        sched_c.step()

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(critic.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")

    history["train_time_s"] = time.time() - t_start
    return history
# ────────────────────────────────────────────────────────────────────────────
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
# ────────────────────────────────────────────────────────────────────────────
def _save_vae_samples(
    vae,
    samples_dir: Path,
    epoch: int,
    *,
    fixed_z: torch.Tensor,
    fixed_real: torch.Tensor,
    device,
) -> None:
    """Save a prior-sample grid and an interleaved real/reconstruction grid.

    Writes ``epoch_XXXX.png`` (decoded ``fixed_z``) and
    ``epoch_XXXX_recon.png`` (real image followed by its reconstruction, pair
    by pair) into ``samples_dir``.

    Args:
        vae: Model exposing ``decode(z)`` and a forward pass whose first
            return value is the reconstruction.
        samples_dir: Output directory (created if missing).
        epoch: Epoch number used in the file names.
        fixed_z: Latent batch decoded for the prior grid.
        fixed_real: Real image batch, expected in ``[-1, 1]``.
        device: Device on which inference runs.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    # Remember the current mode: callers pass the EMA copy, which must not be
    # flipped into train mode as a side effect of saving images.
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            prior = vae.decode(fixed_z.to(device))
            prior = (prior.clamp(-1, 1) + 1.0) / 2.0
            save_image(prior, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)

            real_dev = fixed_real.to(device)
            recon, _, _ = vae(real_dev)
            recon = (recon.clamp(-1, 1) + 1.0) / 2.0
            real = (real_dev + 1.0) / 2.0
            # Interleave real / reconstruction pairs
            pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
            save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
    finally:
        vae.train(was_training)
def train_vae(
    vae,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """VAE training loop covering Phase 3.1-3.3 and Phase 5.

    Config toggles:
        lambda_perceptual > 0 -> VGG-16 perceptual loss (Phase 3.2+)
        lambda_adversarial > 0 -> PatchGAN hinge adversarial loss (Phase 3.3)

    Loss: L = L_mse + lambda_perc * L_vgg + lambda_adv * L_adv + beta_kl * L_kl

    KL is computed as mean over latent dimensions (scale-invariant), so
    beta_kl is comparable across different latent_dim values. The effective
    KL weight is ``beta_kl * min(1, epoch / kl_warmup_epochs)`` with
    ``kl_warmup_epochs = max(1, epochs // 5)``.

    AMP is intentionally disabled for VAE training: mixed-precision float16
    overflows when the KL divergence spikes, producing NaN cascades that
    corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
    computation runs in float32.

    Args:
        vae: Model whose forward returns ``(recon, mu, log_var)`` and which
            exposes ``decode`` / ``sample`` (used for visualisation and FID).
        train_dataset: Dataset whose batches are plain image tensors.
        cfg: Hyper-parameters. Required: ``epochs``, ``batch_size``. Optional
            (default): ``lr`` (1e-3), ``latent_dim`` (256), ``beta_kl`` (1.0),
            ``lambda_perceptual`` (0.0), ``lambda_adversarial`` (0.0),
            ``lr_d`` (1e-4), ``grad_clip`` (1.0), ``ema_decay`` (0.9999),
            ``sample_interval`` (10), ``fid_interval`` (25),
            ``fid_n_real`` (5000), ``ndf_patch`` (64), ``image_size`` (64),
            ``num_workers``.
        save_dir: Directory for checkpoints; sample grids go to
            ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with per-epoch lists ``recon_loss``, ``kl_loss``,
        ``perc_loss``, ``adv_g_loss``, ``adv_d_loss``, a ``fid`` dict keyed by
        epoch, and ``train_time_s``.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)

    n_vae = sum(p.numel() for p in vae.parameters() if p.requires_grad)
    print(f"VAE: {n_vae:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 1e-3)
    latent_dim = cfg.get("latent_dim", 256)
    beta_kl = cfg.get("beta_kl", 1.0)
    lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
    lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
    lr_d = cfg.get("lr_d", 1e-4)
    grad_clip = cfg.get("grad_clip", 1.0)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)
    num_workers = cfg.get("num_workers", min(4, os.cpu_count() or 1))

    use_perceptual = lambda_perceptual > 0
    use_adversarial = lambda_adversarial > 0

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"), drop_last=True,
    )

    # AMP disabled (no scaler / autocast is created): float16 overflows on KL
    # spikes, causing NaN cascades.
    opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)

    # KL warmup: linearly ramp the KL weight up to beta_kl over the first 20%
    # of training.
    kl_warmup_epochs = max(1, epochs // 5)

    # Linear LR decay from epoch epochs//2 to epochs
    decay_start = epochs // 2
    decay_span = max(epochs - decay_start, 1)

    def lr_factor(ep):
        return max(0.0, 1.0 - max(ep - decay_start, 0) / decay_span)

    sched_vae = torch.optim.lr_scheduler.LambdaLR(opt_vae, lr_lambda=lr_factor)
    sched_d = None  # set below if adversarial

    # -- Optional components -------------------------------------------
    perc_fn = None
    patchgan = None
    opt_d = None

    if use_perceptual:
        from src.training.perceptual import PerceptualLoss
        perc_fn = PerceptualLoss().to(device).float()
        print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")

    if use_adversarial:
        from src.models.patchgan import PatchGANDiscriminator, hinge_d_loss, hinge_g_loss
        patchgan = PatchGANDiscriminator(
            ndf=cfg.get("ndf_patch", 64),
            image_size=cfg.get("image_size", 64),
        ).to(device).float()
        opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
        sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, lr_lambda=lr_factor)
        n_d = sum(p.numel() for p in patchgan.parameters())
        print(f"PatchGAN: {n_d:,} params")
    else:
        hinge_d_loss = hinge_g_loss = None  # never called

    # -- Fixed seeds for consistent visualisation ----------------------
    fixed_z = torch.randn(16, latent_dim, device=device)
    # Grab up to 16 real images from the first batch for reconstruction
    # tracking. The temporary iterator is not bound to a name, so it is not
    # kept alive for the rest of training.
    fixed_real = next(iter(loader))[:16].cpu()

    ema = EMA(vae, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=num_workers)

    history = {
        "recon_loss": [], "kl_loss": [], "perc_loss": [],
        "adv_g_loss": [], "adv_d_loss": [], "fid": {},
    }
    best_fid = float("inf")
    nan_skipped = 0

    print(
        f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
        f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
        f" λ_adv={lambda_adversarial}"
    )

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        vae.train()
        if patchgan is not None:
            patchgan.train()

        # KL warmup weight depends only on the epoch, so compute it once here
        # rather than on every batch.
        current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)

        recon_sum = kl_sum = perc_sum = adv_g_sum = adv_d_sum = 0.0
        n_batches = 0

        for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            real = real.to(device).float()

            # -- VAE forward (float32, no AMP) -------------------------
            recon, mu, log_var = vae(real)

            mse = F.mse_loss(recon, real)
            # KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
            kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
            perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()

            vae_loss = mse + current_beta * kl + lambda_perceptual * perc

            # -- NaN/Inf guard -----------------------------------------
            if not torch.isfinite(vae_loss):
                nan_skipped += 1
                opt_vae.zero_grad()
                continue

            # -- PatchGAN discriminator step ---------------------------
            adv_d = real.new_zeros(1).squeeze()
            if use_adversarial:
                opt_d.zero_grad()
                d_real = patchgan(real)
                d_fake = patchgan(recon.detach())
                adv_d = hinge_d_loss(d_real, d_fake)
                # Skip the discriminator update when its loss is not finite.
                if torch.isfinite(adv_d):
                    adv_d.backward()
                    torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
                    opt_d.step()

            # -- PatchGAN generator adversarial loss -------------------
            adv_g = real.new_zeros(1).squeeze()
            if use_adversarial:
                adv_g = hinge_g_loss(patchgan(recon))
                vae_loss = vae_loss + lambda_adversarial * adv_g
                # The first guard ran before this term was added; re-check so
                # a non-finite adversarial loss cannot reach the VAE weights.
                if not torch.isfinite(vae_loss):
                    nan_skipped += 1
                    opt_vae.zero_grad()
                    continue

            # -- VAE backward ------------------------------------------
            opt_vae.zero_grad()
            vae_loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
            opt_vae.step()

            ema.update(vae)

            recon_sum += mse.item()
            kl_sum += kl.item()
            perc_sum += perc.item()
            adv_g_sum += adv_g.item()
            adv_d_sum += adv_d.item()
            n_batches += 1

        # Averages are over batches that actually produced a VAE update.
        denom = max(n_batches, 1)
        avg_r = recon_sum / denom
        avg_k = kl_sum / denom
        avg_p = perc_sum / denom
        avg_g = adv_g_sum / denom
        avg_d = adv_d_sum / denom

        history["recon_loss"].append(avg_r)
        history["kl_loss"].append(avg_k)
        history["perc_loss"].append(avg_p)
        history["adv_g_loss"].append(avg_g)
        history["adv_d_loss"].append(avg_d)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
            f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
            f" (NaN skipped: {nan_skipped})"
        )

        if epoch % sample_interval == 0:
            _save_vae_samples(
                ema.model, samples_dir, epoch,
                fixed_z=fixed_z, fixed_real=fixed_real, device=device,
            )

        if epoch % fid_interval == 0:
            ema.model.eval()
            with torch.no_grad():
                # Generate in chunks of 64, then trim to exactly fid_n_real.
                fake_imgs = torch.cat([
                    ema.model.sample(64, device)
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_vae.step()
        if sched_d is not None:
            sched_d.step()

    torch.save(vae.state_dict(), save_dir / f"{run_name}_final_vae.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    if patchgan is not None:
        torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")

    history["train_time_s"] = time.time() - t_start
    print(f"Total NaN-skipped batches: {nan_skipped}")
    return history
# ────────────────────────────────────────────────────────────────────────────
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
# ────────────────────────────────────────────────────────────────────────────
def train_ddpm(
    model,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """DDPM training loop (Ho et al., 2020) covering Phase 4.1-4.4.

    Config keys:
        epochs, batch_size -- required
        noise_schedule -- "linear" (4.1) or "cosine" (4.2+)
        pred_type -- "eps" (4.1-4.2) or "v" (4.3+)
        T -- diffusion timesteps (default 1000)
        base_ch / ch_mult / attn_resolutions -- U-Net capacity (see unet.py)
        ddim_steps -- DDIM steps for FID evaluation (default 100)
        lr (2e-4), image_size (64), grad_clip (1.0), ema_decay (0.9999),
        sample_interval (10), fid_interval (25), fid_n_real (5000),
        num_workers -- optional, defaults in parentheses

    Args:
        model: Noise-prediction network passed to ``diffusion_loss`` and
            ``ddim_sample``.
        train_dataset: Dataset whose batches are plain image tensors.
        cfg: Hyper-parameter dict (keys above).
        save_dir: Directory for checkpoints; sample grids go to
            ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with the per-epoch ``loss`` list, a ``fid`` dict keyed
        by epoch, and ``train_time_s``.
    """
    from src.training.diffusion import (
        linear_betas, cosine_betas, make_alpha_bars,
        diffusion_loss, ddim_sample,
    )

    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"U-Net: {n_params:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 2e-4)
    T = cfg.get("T", 1000)
    noise_schedule = cfg.get("noise_schedule", "linear")
    pred_type = cfg.get("pred_type", "eps")
    ddim_steps = cfg.get("ddim_steps", 100)
    image_size = cfg.get("image_size", 64)
    # Same key and default as the VAE loop; 1.0 was previously hard-coded.
    grad_clip = cfg.get("grad_clip", 1.0)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)
    num_workers = cfg.get("num_workers", min(4, os.cpu_count() or 1))

    # Build the noise schedule on the training device
    betas = (cosine_betas(T) if noise_schedule == "cosine" else linear_betas(T)).to(device)
    alpha_bars = make_alpha_bars(betas)  # on device

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"), drop_last=True,
    )

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    use_amp = device.type == "cuda"
    scaler = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(model, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
                            num_workers=num_workers)

    history = {"loss": [], "fid": {}}
    best_fid = float("inf")

    print(
        f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
        f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
    )

    # Linear LR decay from epoch epochs//2 to epochs
    decay_start = epochs // 2
    decay_span = max(epochs - decay_start, 1)

    def lr_factor(ep):
        return max(0.0, 1.0 - max(ep - decay_start, 0) / decay_span)

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_factor)

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        n_batches = 0

        for x0 in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            x0 = x0.to(device)
            # One uniformly drawn timestep per sample.
            t = torch.randint(0, T, (x0.size(0),), device=device)

            with _autocast("cuda", enabled=use_amp):
                loss = diffusion_loss(model, x0, t, alpha_bars, pred_type)

            opt.zero_grad()
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(opt)
            scaler.update()

            ema.update(model)
            loss_sum += loss.item()
            n_batches += 1

        # max(..., 1) keeps an epoch without any batch from raising
        # ZeroDivisionError.
        avg_loss = loss_sum / max(n_batches, 1)
        history["loss"].append(avg_loss)
        print(f"[{epoch:03d}/{epochs}] Loss: {avg_loss:.5f}")

        if epoch % sample_interval == 0:
            samples_dir.mkdir(parents=True, exist_ok=True)
            ema.model.eval()
            with torch.no_grad():
                # Quick visualisation: 16 images via 50-step DDIM. No starting
                # latent is handed to ddim_sample here, so each grid is not
                # tied to a fixed noise tensor.
                imgs = ddim_sample(
                    ema.model, 16, image_size, alpha_bars,
                    n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
                )
            imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0
            save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)

        if epoch % fid_interval == 0:
            ema.model.eval()
            # Sampling needs no gradients; no_grad avoids building a graph
            # across all DDIM steps for fid_n_real images.
            with torch.no_grad():
                fake_imgs = ddim_sample(
                    ema.model, fid_n_real, image_size, alpha_bars,
                    n_steps=ddim_steps, pred_type=pred_type,
                    device=str(device), batch_size=32,
                )
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched.step()

    torch.save(model.state_dict(), save_dir / f"{run_name}_final_unet.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")

    history["train_time_s"] = time.time() - t_start
    return history