Clean state
This commit is contained in:
@@ -0,0 +1,166 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.utils import save_image
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.training.ema import EMA
|
||||
from src.training.fid import FIDEvaluator
|
||||
|
||||
# AMP compatibility shim: expose `_GradScaler(device, enabled=...)` and
# `_autocast(device_type, enabled=...)` with one calling convention regardless
# of which AMP namespace the installed PyTorch provides.
if hasattr(torch.amp, "GradScaler"):
    # Newer PyTorch releases expose the AMP helpers directly under torch.amp.
    _GradScaler = torch.amp.GradScaler
    _autocast = torch.amp.autocast
else:
    from torch.cuda.amp import GradScaler as _GS, autocast as _AC

    def _GradScaler(device="", enabled=True, **kw):
        """Build a legacy ``torch.cuda.amp.GradScaler``.

        ``device`` is accepted only for signature parity with
        ``torch.amp.GradScaler`` and is ignored. ``enabled`` is forwarded so
        that a caller asking for a disabled scaler actually gets one.
        """
        return _GS(enabled=enabled, **kw)

    def _autocast(device_type="", enabled=True, **kw):
        """Build a legacy ``torch.cuda.amp.autocast`` context manager.

        ``device_type`` is accepted only for signature parity with
        ``torch.amp.autocast`` and is ignored.
        """
        return _AC(enabled=enabled, **kw)
|
||||
|
||||
|
||||
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, latent_dim: int, device) -> None:
    """Render 16 images from the EMA generator and save them as one PNG grid.

    Args:
        generator_ema: EMA wrapper; its ``.model`` attribute is called on the noise.
        samples_dir: Output directory, created (with parents) if missing.
        epoch: Epoch number, used in the file name ``epoch_XXXX.png``.
        latent_dim: Channel dimension of the ``(16, latent_dim, 1, 1)`` noise tensor.
        device: Device on which the noise tensor is allocated.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    target = samples_dir / f"epoch_{epoch:04d}.png"

    with torch.no_grad():
        z = torch.randn(16, latent_dim, 1, 1, device=device)
        generated = generator_ema.model(z)
        # Clamp to [-1, 1], then shift/scale into [0, 1] for image saving.
        generated = generated.clamp(-1, 1)
        generated = (generated + 1.0) / 2.0

    # 16 images laid out 4 per row.
    save_image(generated, target, nrow=4)
|
||||
|
||||
|
||||
def train_dcgan(
    generator,
    discriminator,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).

    Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
    penalty, no n_critic, single G/D step per batch.

    Args:
        generator: Module called on noise of shape ``(B, latent_dim, 1, 1)``.
        discriminator: Module called on image batches; its raw output is fed
            to ``BCEWithLogitsLoss`` against targets of shape ``(B,)``.
        train_dataset: Dataset whose batches are image tensors (each batch is
            moved to the device as-is). Also handed to ``FIDEvaluator``.
        cfg: Hyper-parameters. Required keys: ``epochs``, ``batch_size``.
            Optional keys (defaults): ``lr_g`` (2e-4), ``lr_d`` (2e-4),
            ``beta1`` (0.5), ``beta2`` (0.999), ``latent_dim`` (100),
            ``ema_decay`` (0.9999), ``sample_interval`` (10),
            ``fid_interval`` (25), ``fid_n_real`` (5000). A falsy
            ``sample_interval`` / ``fid_interval`` (0 or None) disables
            sample dumps / FID evaluation respectively.
        save_dir: Directory for checkpoints; created if missing. Sample grids
            go to ``<save_dir>/../samples/<run_name>``.
        run_name: Prefix for checkpoint file names and the samples sub-directory.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with per-epoch lists ``g_loss``, ``d_loss``, ``d_real``,
        ``d_fake`` (the latter two are means of the raw discriminator outputs)
        and ``fid``, a dict mapping epoch number to FID score.

    Raises:
        ValueError: If the dataset yields no full batch for ``batch_size``.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params Discriminator: {n_d:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 2e-4)
    lr_d = cfg.get("lr_d", 2e-4)
    beta1 = cfg.get("beta1", 0.5)
    beta2 = cfg.get("beta2", 0.999)
    latent_dim = cfg.get("latent_dim", 100)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    # drop_last=True discards a trailing partial batch, so a dataset smaller
    # than batch_size produces zero batches. Fail early with a clear message
    # instead of dividing by zero when averaging the epoch metrics.
    if len(loader) == 0:
        raise ValueError(
            f"train_dataset yields no full batch of size {batch_size} "
            f"(drop_last=True); reduce batch_size or provide more data."
        )

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
    bce = nn.BCEWithLogitsLoss()

    # Mixed precision is used only on CUDA; on CPU both scalers and autocast
    # are constructed disabled and act as pass-throughs.
    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)
    scaler_d = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    # Skip building the evaluator when FID is disabled.
    fid_eval = (
        FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
        if fid_interval
        else None
    )

    history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")
    print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")

    for epoch in range(1, epochs + 1):
        generator.train()
        discriminator.train()
        g_sum = d_sum = real_sum = fake_sum = 0.0
        n_batches = 0

        for imgs in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            # non_blocking lets the copy overlap with compute when the loader
            # pins memory; it is a plain copy otherwise.
            imgs = imgs.to(device, non_blocking=True)
            bsz = imgs.size(0)
            real_labels = torch.ones(bsz, device=device)
            fake_labels = torch.zeros(bsz, device=device)

            # ── Discriminator step ────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                # detach(): the D update must not backpropagate into G.
                fake = generator(noise).detach()
                d_real = discriminator(imgs)
                d_fake = discriminator(fake)
                d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)
            opt_d.zero_grad()
            scaler_d.scale(d_loss).backward()
            scaler_d.step(opt_d)
            scaler_d.update()

            # ── Generator step ────────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                # G is trained to make D label its samples as real.
                g_loss = bce(discriminator(generator(noise)), real_labels)
            opt_g.zero_grad()
            scaler_g.scale(g_loss).backward()
            scaler_g.step(opt_g)
            scaler_g.update()
            ema.update(generator)

            g_sum += g_loss.item()
            d_sum += d_loss.item()
            # Means of the raw discriminator outputs from the D step.
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_batches += 1

        avg_g = g_sum / n_batches
        avg_d = d_sum / n_batches
        avg_r = real_sum / n_batches
        avg_f = fake_sum / n_batches
        history["g_loss"].append(avg_g)
        history["d_loss"].append(avg_d)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)
        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} D: {avg_d:.4f} D(real): {avg_r:.4f} D(fake): {avg_f:.4f}"
        )

        # A falsy interval disables the feature instead of raising on `% 0`.
        if sample_interval and epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device)

        if fid_interval and epoch % fid_interval == 0:
            generator.eval()
            with torch.no_grad():
                # Generate in chunks of 64 and trim to exactly fid_n_real images.
                fake_imgs = torch.cat([
                    generator(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
                fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            # Keep the raw and EMA generator weights from the best-FID epoch.
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    return history
|
||||
Reference in New Issue
Block a user