Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+166
View File
@@ -0,0 +1,166 @@
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
from src.training.ema import EMA
from src.training.fid import FIDEvaluator
# AMP compatibility shim.
# Newer PyTorch releases expose device-agnostic ``torch.amp.GradScaler`` /
# ``torch.amp.autocast`` that take the device type as first argument. Older
# releases only ship the CUDA-specific classes in ``torch.cuda.amp``, which
# take no device argument, so thin adapters with the new call signature are
# defined for them.
# ``getattr(..., None)`` keeps the probe safe on releases without ``torch.amp``.
if hasattr(getattr(torch, "amp", None), "GradScaler"):
    _GradScaler = torch.amp.GradScaler
    _autocast = torch.amp.autocast
else:
    from torch.cuda.amp import GradScaler as _GS, autocast as _AC

    def _GradScaler(device="", enabled=True, **kw):
        """Build a legacy CUDA ``GradScaler``; ``device`` is accepted and ignored.

        ``enabled`` is forwarded so that callers can actually switch scaling
        off (it was previously dropped, leaving the scaler always enabled).
        """
        return _GS(enabled=enabled, **kw)

    def _autocast(device_type="", enabled=True, **kw):
        """Build a legacy CUDA ``autocast`` context; ``device_type`` is ignored."""
        return _AC(enabled=enabled, **kw)
def _save_samples(
    generator_ema,
    samples_dir: Path,
    epoch: int,
    *,
    latent_dim: int,
    device,
    n_samples: int = 16,
    nrow: int = 4,
) -> None:
    """Sample images from the EMA generator and write them as one PNG grid.

    Args:
        generator_ema: Object whose ``.model`` attribute is the callable
            generator to sample from (here, the EMA wrapper).
        samples_dir: Directory for the output image; created if missing.
        epoch: Epoch number, used for the file name ``epoch_XXXX.png``.
        latent_dim: Number of channels of the ``(N, latent_dim, 1, 1)`` noise.
        device: Device on which the noise tensor is created.
        n_samples: Number of images to generate (default 16).
        nrow: Images per grid row, passed to ``save_image`` (default 4).
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        noise = torch.randn(n_samples, latent_dim, 1, 1, device=device)
        imgs = generator_ema.model(noise)
        # Generator output is treated as living in [-1, 1]; clamp any
        # overshoot, then rescale to the [0, 1] range before saving.
        imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0
    save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=nrow)
def train_dcgan(
    generator,
    discriminator,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).

    Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
    penalty, no n_critic, single G/D step per batch.

    Args:
        generator: Module called on noise of shape ``(B, latent_dim, 1, 1)``.
        discriminator: Module whose output is fed to ``BCEWithLogitsLoss``
            against labels of shape ``(B,)`` (so it is treated as raw logits).
        train_dataset: Dataset whose batches are image tensors only (each
            batch is moved to the device and used directly as the real images).
        cfg: Hyper-parameters. Required keys: ``epochs``, ``batch_size``.
            Optional keys (default): ``lr_g`` (2e-4), ``lr_d`` (2e-4),
            ``beta1`` (0.5), ``beta2`` (0.999), ``latent_dim`` (100),
            ``ema_decay`` (0.9999), ``sample_interval`` (10),
            ``fid_interval`` (25), ``fid_n_real`` (5000). A falsy
            ``sample_interval`` / ``fid_interval`` (0 or None) disables the
            corresponding periodic step.
        save_dir: Directory for checkpoints; created if missing. Sample grids
            go to ``save_dir.parent / "samples" / run_name``.
        run_name: Prefix for checkpoint file names and the samples sub-folder.
        device: Requested device; falls back to CPU when CUDA is unavailable.

    Returns:
        History dict with per-epoch lists ``g_loss``, ``d_loss``, ``d_real``,
        ``d_fake`` (the latter two are means of the raw discriminator outputs)
        and ``fid``, a dict mapping epoch number to FID score.

    Raises:
        ValueError: If the data loader yields no batches (dataset smaller
            than ``batch_size`` while ``drop_last=True``).
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params Discriminator: {n_d:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 2e-4)
    lr_d = cfg.get("lr_d", 2e-4)
    beta1 = cfg.get("beta1", 0.5)
    beta2 = cfg.get("beta2", 0.999)
    latent_dim = cfg.get("latent_dim", 100)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    # Fail fast with a clear message: with drop_last=True a dataset smaller
    # than one batch produces no iterations, which would otherwise surface
    # later as a ZeroDivisionError when averaging the epoch statistics.
    if len(loader) == 0:
        raise ValueError(
            f"DataLoader produced zero batches: the dataset is smaller than "
            f"batch_size={batch_size} and drop_last=True."
        )

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
    bce = nn.BCEWithLogitsLoss()

    # Mixed precision is only used on CUDA; on CPU the scalers/autocast are
    # constructed disabled and act as pass-throughs.
    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)
    scaler_d = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
    history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")

    print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")

    for epoch in range(1, epochs + 1):
        # Re-enter train mode every epoch: the FID step below switches the
        # generator to eval mode.
        generator.train()
        discriminator.train()
        g_sum = d_sum = real_sum = fake_sum = 0.0
        n_batches = 0

        for imgs in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            imgs = imgs.to(device)
            bsz = imgs.size(0)
            real_labels = torch.ones(bsz, device=device)
            fake_labels = torch.zeros(bsz, device=device)

            # ── Discriminator step ────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                # detach(): the discriminator update must not backprop into G.
                fake = generator(noise).detach()
                d_real = discriminator(imgs)
                d_fake = discriminator(fake)
                d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)
            opt_d.zero_grad()
            scaler_d.scale(d_loss).backward()
            scaler_d.step(opt_d)
            scaler_d.update()

            # ── Generator step ────────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                # Non-saturating loss: G is rewarded when D labels fakes real.
                g_loss = bce(discriminator(generator(noise)), real_labels)
            opt_g.zero_grad()
            scaler_g.scale(g_loss).backward()
            scaler_g.step(opt_g)
            scaler_g.update()
            ema.update(generator)

            g_sum += g_loss.item()
            d_sum += d_loss.item()
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_batches += 1

        avg_g = g_sum / n_batches
        avg_d = d_sum / n_batches
        avg_r = real_sum / n_batches
        avg_f = fake_sum / n_batches
        history["g_loss"].append(avg_g)
        history["d_loss"].append(avg_d)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)
        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} D: {avg_d:.4f} D(real): {avg_r:.4f} D(fake): {avg_f:.4f}"
        )

        # A falsy interval (0 / None) disables the step instead of raising
        # on the modulo.
        if sample_interval and epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device)

        if fid_interval and epoch % fid_interval == 0:
            # NOTE: FID is computed on the live generator, not the EMA copy;
            # both sets of weights are checkpointed when it improves.
            generator.eval()
            with torch.no_grad():
                fake_imgs = torch.cat([
                    generator(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    return history