Preview of phase 2-5 implementation; needs a full check
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.utils import save_image
|
||||
from tqdm import tqdm
|
||||
@@ -19,12 +21,11 @@ else:
|
||||
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
|
||||
|
||||
|
||||
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
    """Write a 4-column PNG grid of EMA-generator samples for ``epoch``.

    The same ``fixed_noise`` batch is decoded on every call, so grids from
    different epochs show the same latents and can be compared directly.

    Args:
        generator_ema: EMA wrapper; its ``.model`` attribute is called on the noise.
        samples_dir: Output directory (created if missing).
        epoch: Epoch number, used in the file name ``epoch_XXXX.png``.
        fixed_noise: Latent batch fed to the generator.
        device: Device the noise is moved to before the forward pass.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        imgs = generator_ema.model(fixed_noise.to(device))  # EMA model, [-1, 1]
        imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0  # -> [0, 1]
    save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
||||
|
||||
|
||||
@@ -78,6 +79,9 @@ def train_dcgan(
|
||||
|
||||
ema = EMA(generator, decay=ema_decay)
|
||||
|
||||
# Fixed noise for consistent sample tracking across epochs
|
||||
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
@@ -88,6 +92,15 @@ def train_dcgan(
|
||||
best_fid = float("inf")
|
||||
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched_g = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
|
||||
t_start = time.time()
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
generator.train()
|
||||
discriminator.train()
|
||||
@@ -142,13 +155,13 @@ def train_dcgan(
|
||||
)
|
||||
|
||||
if epoch % sample_interval == 0:
|
||||
_save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device)
|
||||
_save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
generator.eval()
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
fake_imgs = torch.cat([
|
||||
generator(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
for _ in range(fid_n_real // 64 + 1)
|
||||
])[:fid_n_real]
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
@@ -160,7 +173,586 @@ def train_dcgan(
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
|
||||
sched_g.step()
|
||||
sched_d.step()
|
||||
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
|
||||
torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
|
||||
history["train_time_s"] = time.time() - t_start
|
||||
return history
|
||||
|
||||
|
||||
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
    """Two-sided gradient penalty (Gulrajani et al., 2017).

    Scores the critic on random per-sample interpolations between ``real`` and
    ``fake`` and penalises the squared deviation of the per-sample input-gradient
    L2 norm from 1.

    Args:
        critic: Callable mapping a batch to critic scores.
        real: Batch of real samples, shape ``(B, ...)`` with at least 2 dims.
        fake: Batch of generated samples with the same shape as ``real``.
        device: Device on which the interpolation coefficients are drawn.

    Returns:
        Scalar tensor; built with ``create_graph=True`` so it can be
        back-propagated into the critic parameters.
    """
    bsz = real.size(0)
    # One mixing coefficient per sample, broadcast over every trailing dim.
    # Sizing from real.dim() keeps this correct for non-image (e.g. 2-D) batches.
    eps = torch.rand(bsz, *([1] * (real.dim() - 1)), device=device)
    interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    d_interp = critic(interp)
    grad = torch.autograd.grad(
        outputs=d_interp,
        inputs=interp,
        grad_outputs=torch.ones_like(d_interp),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]
    # Per-sample L2 norm over all non-batch dims, whatever the rank.
    grad_norm = grad.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()
|
||||
|
||||
|
||||
def train_wgan(
    generator,
    critic,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """WGAN-GP training loop (Gulrajani et al., 2017).

    Used for Phase 2.2–2.4. Gradient penalty replaces weight clipping.
    The critic runs in float32 to keep GP gradient computation numerically
    stable; AMP is used only for the generator forward/backward.

    Args:
        generator: Generator module; called on noise of shape ``(B, latent_dim, 1, 1)``.
        critic: Critic module; called on image batches.
        train_dataset: Dataset whose items are image tensors (batches are used directly).
        cfg: Config dict. ``epochs`` and ``batch_size`` are required; optional keys
            (with defaults) are ``lr_g``/``lr_d`` (1e-4), ``beta1`` (0.0), ``beta2`` (0.9),
            ``latent_dim`` (128), ``n_critic`` (5), ``gp_lambda`` (10),
            ``ema_decay`` (0.9999), ``sample_interval`` (10), ``fid_interval`` (25),
            ``fid_n_real`` (5000).
        save_dir: Checkpoint directory; sample grids go to ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with per-epoch lists ``g_loss``, ``w_dist``, ``gp``, ``d_real``,
        ``d_fake``, a ``fid`` dict keyed by epoch, and ``train_time_s``.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    critic = critic.to(device)

    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_c = sum(p.numel() for p in critic.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params Critic: {n_c:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 1e-4)
    lr_d = cfg.get("lr_d", 1e-4)
    beta1 = cfg.get("beta1", 0.0)
    beta2 = cfg.get("beta2", 0.9)
    latent_dim = cfg.get("latent_dim", 128)
    n_critic = cfg.get("n_critic", 5)
    gp_lambda = cfg.get("gp_lambda", 10)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_c = torch.optim.Adam(critic.parameters(), lr=lr_d, betas=(beta1, beta2))

    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    # Fixed noise for consistent sample tracking across epochs
    fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))

    history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")
    print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")

    # Linear LR decay from epoch epochs//2 to epochs.
    # The span is clamped to >= 1 (as in the VAE/DDPM loops) so the factor is
    # always well defined.
    decay_start = epochs // 2
    decay_span = max(epochs - decay_start, 1)

    def _lr_factor(ep: int) -> float:
        return max(0.0, 1.0 - max(ep - decay_start, 0) / decay_span)

    sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=_lr_factor)
    sched_c = torch.optim.lr_scheduler.LambdaLR(opt_c, lr_lambda=_lr_factor)

    t_start = time.time()
    # Critic updates counted across the whole run, not per epoch: the generator
    # then gets exactly one update per n_critic critic updates even when the
    # number of batches per epoch is smaller than, or not a multiple of, n_critic.
    total_c_steps = 0

    for epoch in range(1, epochs + 1):
        generator.train()
        critic.train()
        g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
        n_c_steps = n_g_steps = 0

        for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            real = real.to(device)
            bsz = real.size(0)

            # ── Critic step (every batch) ─────────────────────────────────
            # Run critic in float32 — GP needs second-order gradients
            # (a backward pass through the gradient) and AMP can degrade
            # stability here.
            opt_c.zero_grad()
            with torch.no_grad():
                fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))

            real_f32 = real.float()
            fake_f32 = fake.float()  # produced under no_grad, so already graph-free

            d_real = critic(real_f32)
            d_fake = critic(fake_f32)
            gp = _gradient_penalty(critic, real_f32, fake_f32, device)
            c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
            c_loss.backward()
            opt_c.step()

            # Read each mean back once; reused for the Wasserstein estimate.
            d_real_mean = d_real.mean().item()
            d_fake_mean = d_fake.mean().item()
            w_sum += d_real_mean - d_fake_mean
            gp_sum += gp.item()
            real_sum += d_real_mean
            fake_sum += d_fake_mean
            n_c_steps += 1
            total_c_steps += 1

            # ── Generator step (every n_critic critic steps) ──────────────
            if total_c_steps % n_critic == 0:
                opt_g.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
                # Score outside autocast so the critic really runs in float32,
                # as documented; only the generator uses AMP.
                g_loss = -critic(fake.float()).mean()
                scaler_g.scale(g_loss).backward()
                scaler_g.step(opt_g)
                scaler_g.update()
                ema.update(generator)
                g_sum += g_loss.item()
                n_g_steps += 1

        avg_w = w_sum / max(n_c_steps, 1)
        avg_gp = gp_sum / max(n_c_steps, 1)
        avg_g = g_sum / max(n_g_steps, 1)
        avg_r = real_sum / max(n_c_steps, 1)
        avg_f = fake_sum / max(n_c_steps, 1)
        history["g_loss"].append(avg_g)
        history["w_dist"].append(avg_w)
        history["gp"].append(avg_gp)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)
        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} W-dist: {avg_w:.4f} GP: {avg_gp:.4f} "
            f"C(real): {avg_r:.4f} C(fake): {avg_f:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)

        if epoch % fid_interval == 0:
            ema.model.eval()
            with torch.no_grad():
                # Generate in chunks of 64, then trim to exactly fid_n_real images.
                fake_imgs = torch.cat([
                    ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_g.step()
        sched_c.step()

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(critic.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    history["train_time_s"] = time.time() - t_start
    return history
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _save_vae_samples(
    vae,
    samples_dir: Path,
    epoch: int,
    *,
    fixed_z: torch.Tensor,
    fixed_real: torch.Tensor,
    device,
) -> None:
    """Save a prior-sample grid and an interleaved real/reconstruction grid.

    Writes ``epoch_XXXX.png`` (decoded ``fixed_z``) and ``epoch_XXXX_recon.png``
    (real, reconstruction, real, reconstruction, ...) into ``samples_dir``.
    The model is put in eval mode for the forward passes and returned to
    whatever mode it was in on entry.

    Args:
        vae: Model exposing ``decode(z)`` and a forward pass returning a
            3-tuple whose first element is the reconstruction.
        samples_dir: Output directory (created if missing).
        epoch: Epoch number used in the file names.
        fixed_z: Latent batch decoded for the prior grid.
        fixed_real: Real image batch to reconstruct; rescaled with ``(x + 1) / 2``
            for display.
        device: Device the tensors are moved to.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    # Remember the caller's mode: this is invoked with the EMA copy, and
    # unconditionally switching it to train mode afterwards would silently
    # change how later forward passes on that copy behave.
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            prior = vae.decode(fixed_z.to(device))
            prior = (prior.clamp(-1, 1) + 1.0) / 2.0
            save_image(prior, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)

            real_dev = fixed_real.to(device)  # moved once, reused below
            recon, _, _ = vae(real_dev)
            recon = (recon.clamp(-1, 1) + 1.0) / 2.0
            real = (real_dev + 1.0) / 2.0
            # Interleave real / reconstruction pairs
            pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
            save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
    finally:
        vae.train(was_training)
|
||||
|
||||
|
||||
def train_vae(
    vae,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """VAE training loop covering Phase 3.1 – 3.3.

    Config toggles:
        lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
        lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)

    Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl

    Args:
        vae: Model whose forward returns ``(recon, mu, log_var)``; ``decode`` and
            ``sample`` are used on its EMA copy for visualisation and FID.
        train_dataset: Dataset whose items are image tensors (batches are used directly).
        cfg: Config dict. ``epochs`` and ``batch_size`` are required; optional keys
            (with defaults) are ``lr`` (1e-3), ``latent_dim`` (256), ``beta_kl`` (1.0),
            ``lambda_perceptual`` (0.0), ``lambda_adversarial`` (0.0), ``lr_d`` (1e-4),
            ``ema_decay`` (0.9999), ``sample_interval`` (10), ``fid_interval`` (25),
            ``fid_n_real`` (5000), plus ``ndf_patch`` (64) and ``image_size`` (64)
            for the PatchGAN.
        save_dir: Checkpoint directory; sample grids go to ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with per-epoch lists ``recon_loss``, ``kl_loss``, ``perc_loss``,
        ``adv_g_loss``, ``adv_d_loss``, a ``fid`` dict keyed by epoch, and ``train_time_s``.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)

    n_vae = sum(p.numel() for p in vae.parameters() if p.requires_grad)
    print(f"VAE: {n_vae:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 1e-3)
    latent_dim = cfg.get("latent_dim", 256)
    beta_kl = cfg.get("beta_kl", 1.0)
    lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
    lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
    lr_d = cfg.get("lr_d", 1e-4)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    use_perceptual = lambda_perceptual > 0
    use_adversarial = lambda_adversarial > 0

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )

    opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
    use_amp = device.type == "cuda"
    scaler = _GradScaler("cuda", enabled=use_amp)

    # KL warmup: beta grows linearly with the epoch index and reaches beta_kl
    # after the first 20% of training (at least one epoch).
    kl_warmup_epochs = max(1, epochs // 5)

    # Linear LR decay from epoch epochs//2 to epochs
    decay_start = epochs // 2
    decay_span = max(epochs - decay_start, 1)

    def _lr_factor(ep: int) -> float:
        return max(0.0, 1.0 - max(ep - decay_start, 0) / decay_span)

    sched_vae = torch.optim.lr_scheduler.LambdaLR(opt_vae, lr_lambda=_lr_factor)
    sched_d = None  # set below if adversarial

    # ── Optional components ───────────────────────────────────────────────
    perc_fn = None
    patchgan = None
    opt_d = None
    scaler_d = None

    if use_perceptual:
        from src.training.perceptual import PerceptualLoss
        perc_fn = PerceptualLoss().to(device)
        print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")

    if use_adversarial:
        from src.models.patchgan import PatchGANDiscriminator, hinge_d_loss, hinge_g_loss
        patchgan = PatchGANDiscriminator(
            ndf=cfg.get("ndf_patch", 64),
            image_size=cfg.get("image_size", 64),
        ).to(device)
        opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
        scaler_d = _GradScaler("cuda", enabled=use_amp)
        sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, lr_lambda=_lr_factor)
        n_d = sum(p.numel() for p in patchgan.parameters())
        print(f"PatchGAN: {n_d:,} params")
    else:
        hinge_d_loss = hinge_g_loss = None  # satisfy linter, never called

    # ── Fixed seeds for consistent visualisation ──────────────────────────
    fixed_z = torch.randn(16, latent_dim, device=device)
    # Grab first 16 real images from the loader for reconstruction tracking.
    # The iterator is a temporary, so it is not kept referenced for the
    # rest of the run.
    fixed_real = next(iter(loader))[:16].cpu()

    ema = EMA(vae, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))

    history = {
        "recon_loss": [], "kl_loss": [], "perc_loss": [],
        "adv_g_loss": [], "adv_d_loss": [], "fid": {},
    }
    best_fid = float("inf")
    print(
        f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
        f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
    )

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        vae.train()
        if patchgan is not None:
            patchgan.train()

        recon_sum = kl_sum = perc_sum = adv_g_sum = adv_d_sum = 0.0
        n_batches = 0

        # KL warmup weight depends only on the epoch, so compute it once here
        # rather than per batch; it is also always defined for the log line below.
        current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)

        for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            real = real.to(device)

            # ── VAE forward ───────────────────────────────────────────────
            with _autocast("cuda", enabled=use_amp):
                recon, mu, log_var = vae(real)
                mse = F.mse_loss(recon, real)
                # KL(q(z|x) || N(0, I)), summed over latent dims, averaged over batch
                kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
                perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
                vae_loss = mse + current_beta * kl + lambda_perceptual * perc

            # ── PatchGAN discriminator step ───────────────────────────────
            adv_d = real.new_zeros(1).squeeze()
            if use_adversarial:
                opt_d.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    d_real = patchgan(real)
                    # detach: the discriminator update must not reach the VAE
                    d_fake = patchgan(recon.detach())
                    adv_d = hinge_d_loss(d_real, d_fake)
                scaler_d.scale(adv_d).backward()
                scaler_d.step(opt_d)
                scaler_d.update()

            # ── PatchGAN generator adversarial loss ───────────────────────
            adv_g = real.new_zeros(1).squeeze()
            if use_adversarial:
                with _autocast("cuda", enabled=use_amp):
                    # Fresh forward through the just-updated discriminator
                    adv_g = hinge_g_loss(patchgan(recon))
                vae_loss = vae_loss + lambda_adversarial * adv_g

            # ── VAE backward ──────────────────────────────────────────────
            opt_vae.zero_grad()
            scaler.scale(vae_loss).backward()
            scaler.step(opt_vae)
            scaler.update()
            ema.update(vae)

            recon_sum += mse.item()
            kl_sum += kl.item()
            perc_sum += perc.item()
            adv_g_sum += adv_g.item()
            adv_d_sum += adv_d.item()
            n_batches += 1

        # Guarded like the WGAN loop so an epoch without batches cannot divide by zero.
        denom = max(n_batches, 1)
        avg_r = recon_sum / denom
        avg_k = kl_sum / denom
        avg_p = perc_sum / denom
        avg_g = adv_g_sum / denom
        avg_d = adv_d_sum / denom
        history["recon_loss"].append(avg_r)
        history["kl_loss"].append(avg_k)
        history["perc_loss"].append(avg_p)
        history["adv_g_loss"].append(avg_g)
        history["adv_d_loss"].append(avg_d)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
            f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_vae_samples(
                ema.model, samples_dir, epoch,
                fixed_z=fixed_z, fixed_real=fixed_real, device=device,
            )

        if epoch % fid_interval == 0:
            ema.model.eval()
            with torch.no_grad():
                # Sample in chunks of 64, then trim to exactly fid_n_real images.
                fake_imgs = torch.cat([
                    ema.model.sample(64, device)
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_vae.step()
        if sched_d is not None:
            sched_d.step()

    torch.save(vae.state_dict(), save_dir / f"{run_name}_final_vae.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    if patchgan is not None:
        torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
    history["train_time_s"] = time.time() - t_start
    return history
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def train_ddpm(
    model,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """DDPM training loop (Ho et al., 2020) covering Phase 4.1 – 4.4.

    Config keys:
        noise_schedule — "linear" (4.1) or "cosine" (4.2+)
        pred_type — "eps" (4.1–4.2) or "v" (4.3+)
        T — diffusion timesteps (default 1000)
        base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py)
        ddim_steps — DDIM steps for FID evaluation (default 100)

    Also read from ``cfg``: ``epochs`` and ``batch_size`` (required), ``lr`` (2e-4),
    ``image_size`` (64), ``ema_decay`` (0.9999), ``sample_interval`` (10),
    ``fid_interval`` (25), ``fid_n_real`` (5000).

    Args:
        model: Denoising network passed to ``diffusion_loss`` / ``ddim_sample``.
        train_dataset: Dataset whose items are image tensors (batches are used directly).
        cfg: Config dict (keys above).
        save_dir: Checkpoint directory; sample grids go to ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with the per-epoch ``loss`` list, a ``fid`` dict keyed by
        epoch, and ``train_time_s``.
    """
    from src.training.diffusion import (
        linear_betas, cosine_betas, make_alpha_bars,
        diffusion_loss, ddim_sample,
    )

    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"U-Net: {n_params:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 2e-4)
    T = cfg.get("T", 1000)
    noise_schedule = cfg.get("noise_schedule", "linear")
    pred_type = cfg.get("pred_type", "eps")
    ddim_steps = cfg.get("ddim_steps", 100)
    image_size = cfg.get("image_size", 64)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    # Build noise schedule and register on device
    betas = (cosine_betas(T) if noise_schedule == "cosine" else linear_betas(T)).to(device)
    alpha_bars = make_alpha_bars(betas)  # on device

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    opt = torch.optim.AdamW(model.parameters(), lr=lr)

    use_amp = device.type == "cuda"
    scaler = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(model, decay=ema_decay)

    # NOTE(review): a `fixed_noise` tensor used to be allocated here but was
    # never handed to ddim_sample, so preview grids are NOT the same latents
    # across epochs. If ddim_sample can accept initial latents, create them here
    # and pass them in the preview branch below — confirm against its signature.

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))

    history = {"loss": [], "fid": {}}
    best_fid = float("inf")
    print(
        f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
        f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
    )

    # Linear LR decay from epoch epochs//2 to epochs
    decay_start = epochs // 2
    decay_span = max(epochs - decay_start, 1)

    def _lr_factor(ep: int) -> float:
        return max(0.0, 1.0 - max(ep - decay_start, 0) / decay_span)

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=_lr_factor)

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        n_batches = 0

        for x0 in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            x0 = x0.to(device)
            # One uniformly drawn timestep in [0, T) per sample
            t = torch.randint(0, T, (x0.size(0),), device=device)

            with _autocast("cuda", enabled=use_amp):
                loss = diffusion_loss(model, x0, t, alpha_bars, pred_type)

            opt.zero_grad()
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            ema.update(model)

            loss_sum += loss.item()
            n_batches += 1

        # Guarded like the WGAN loop so an epoch without batches cannot divide by zero.
        avg_loss = loss_sum / max(n_batches, 1)
        history["loss"].append(avg_loss)
        print(f"[{epoch:03d}/{epochs}] Loss: {avg_loss:.5f}")

        if epoch % sample_interval == 0:
            samples_dir.mkdir(parents=True, exist_ok=True)
            ema.model.eval()
            with torch.no_grad():
                # Quick visualisation: 16 images from a short 50-step DDIM run
                imgs = ddim_sample(
                    ema.model, 16, image_size, alpha_bars,
                    n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
                )
            imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0
            save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)

        if epoch % fid_interval == 0:
            ema.model.eval()
            # no_grad, as in the preview branch: sampling fid_n_real images must
            # not record autograd history.
            with torch.no_grad():
                fake_imgs = ddim_sample(
                    ema.model, fid_n_real, image_size, alpha_bars,
                    n_steps=ddim_steps, pred_type=pred_type,
                    device=str(device), batch_size=32,
                )
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched.step()

    torch.save(model.state_dict(), save_dir / f"{run_name}_final_unet.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    history["train_time_s"] = time.time() - t_start
    return history
|
||||
|
||||
Reference in New Issue
Block a user