@@ -403,13 +403,20 @@ def train_vae(
run_name : str ,
device : str = " cuda " ,
) - > dict :
""" VAE training loop covering Phase 3.1 – 3.3.
""" VAE training loop covering Phase 3.1 – 3.3 and Phase 5 .
Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
free_bits > 0 → per-dimension KL free bits (prevents posterior
collapse and KL explosion)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
computation runs in float32.
"""
device = torch . device ( device if torch . cuda . is_available ( ) else " cpu " )
vae = vae . to ( device )
@@ -425,6 +432,8 @@ def train_vae(
lambda_perceptual = cfg . get ( " lambda_perceptual " , 0.0 )
lambda_adversarial = cfg . get ( " lambda_adversarial " , 0.0 )
lr_d = cfg . get ( " lr_d " , 1e-4 )
free_bits_val = cfg . get ( " free_bits " , 0.0 )
grad_clip = cfg . get ( " grad_clip " , 1.0 )
ema_decay = cfg . get ( " ema_decay " , 0.9999 )
sample_interval = cfg . get ( " sample_interval " , 10 )
fid_interval = cfg . get ( " fid_interval " , 25 )
@@ -432,6 +441,7 @@ def train_vae(
use_perceptual = lambda_perceptual > 0
use_adversarial = lambda_adversarial > 0
use_free_bits = free_bits_val > 0
loader = DataLoader (
train_dataset , batch_size = batch_size , shuffle = True ,
@@ -440,8 +450,8 @@ def train_vae(
)
opt_vae = torch . optim . Adam ( vae . parameters ( ) , lr = lr )
use_amp = device . type == " cuda "
scaler = _GradScaler ( " cuda " , enabled = use_amp )
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
use_amp = False
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
kl_warmup_epochs = max ( 1 , epochs / / 5 )
@@ -456,11 +466,10 @@ def train_vae(
perc_fn = None
patchgan = None
opt_d = None
scaler_d = None
if use_perceptual :
from src . training . perceptual import PerceptualLoss
perc_fn = PerceptualLoss ( ) . to ( device )
perc_fn = PerceptualLoss ( ) . to ( device ) . float ( )
print ( " Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3 " )
if use_adversarial :
@@ -468,15 +477,14 @@ def train_vae(
patchgan = PatchGANDiscriminator (
ndf = cfg . get ( " ndf_patch " , 64 ) ,
image_size = cfg . get ( " image_size " , 64 ) ,
) . to ( device )
) . to ( device ) . float ( )
opt_d = torch . optim . Adam ( patchgan . parameters ( ) , lr = lr_d , betas = ( 0.5 , 0.999 ) )
scaler_d = _GradScaler ( " cuda " , enabled = use_amp )
sched_d = torch . optim . lr_scheduler . LambdaLR (
opt_d , lr_lambda = lambda ep : max ( 0.0 , 1.0 - max ( ep - decay_start , 0 ) / max ( epochs - decay_start , 1 ) ) )
n_d = sum ( p . numel ( ) for p in patchgan . parameters ( ) )
print ( f " PatchGAN: { n_d : , } params " )
else :
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called
hinge_d_loss = hinge_g_loss = None # never called
# ── Fixed seeds for consistent visualisation ──────────────────────────
fixed_z = torch . randn ( 16 , latent_dim , device = device )
@@ -497,9 +505,11 @@ def train_vae(
" adv_g_loss " : [ ] , " adv_d_loss " : [ ] , " fid " : { } ,
}
best_fid = float ( " inf " )
nan_skipped = 0
print (
f " Device: { device } AMP: { use_amp } Batches/epoch: { len ( loader ) } "
f " β_kl= { beta_kl } (warmup { kl_warmup_epochs } ep) λ_perc= { lambda_perceptual } λ_adv= { lambda_adversarial } "
f " Device: { device } AMP: disabled (float32) Batches/epoch: { len ( loader ) } "
f " β_kl= { beta_kl } (warmup { kl_warmup_epochs } ep) λ_perc= { lambda_perceptual } "
f " λ_adv= { lambda_adversarial } free_bits= { free_bits_val } "
)
t_start = time . time ( )
@@ -513,43 +523,56 @@ def train_vae(
n_batches = 0
for real in tqdm ( loader , desc = f " Epoch { epoch } / { epochs } " , leave = False ) :
real = real . to ( device )
real = real . to ( device ) . float ( )
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
current_beta = beta_kl * min ( 1.0 , epoch / kl_warmup_epochs )
# ── VAE forward ─────────────────── ────────────────────────────
with _autocast ( " cuda " , enabled = use_amp ) :
# ── VAE forward (float32, no AMP) ────────────────────────────
recon , mu , log_var = vae ( real )
mse = F . mse_loss ( recon , real )
kl = - 0.5 * ( 1 + log_var - mu . pow ( 2 ) - log_var . exp ( ) ) . sum ( 1 ) . mean ( )
# KL divergence with optional free bits
kl_per_dim = - 0.5 * ( 1 + log_var - mu . pow ( 2 ) - log_var . exp ( ) ) # (B, latent_dim)
if use_free_bits :
# Free bits: ensure each dimension contributes at least free_bits_val KL.
# Dimensions below the threshold are raised to it, preventing posterior
# collapse (dimensions that go to 0) while still penalising large KL.
kl_per_dim = torch . clamp ( kl_per_dim , min = free_bits_val )
kl = kl_per_dim . sum ( 1 ) . mean ( )
perc = perc_fn ( recon , real ) if use_perceptual else real . new_zeros ( 1 ) . squeeze ( )
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── NaN/Inf guard ────────────────────────────────────────────
if not torch . isfinite ( vae_loss ) :
nan_skipped + = 1
opt_vae . zero_grad ( )
continue
# ── PatchGAN discriminator step ───────────────────────────────
adv_d = real . new_zeros ( 1 ) . squeeze ( )
if use_adversarial :
opt_d . zero_grad ( )
with _autocast ( " cuda " , enabled = use_amp ) :
d_real = patchgan ( real )
d_fake = patchgan ( recon . detach ( ) )
adv_d = hinge_d_loss ( d_real , d_fake )
scaler_d . scale ( adv_d ) . backward ( )
scaler_d . step ( opt_d )
scaler_d . update ( )
if torch . isfinite ( adv_d ) :
adv_d . backward ( )
torch . nn . utils . clip_grad_norm_ ( patchgan . parameters ( ) , grad_clip )
opt_d . step ( )
# ── PatchGAN generator adversarial loss ───────────────────────
adv_g = real . new_zeros ( 1 ) . squeeze ( )
if use_adversarial :
with _autocast ( " cuda " , enabled = use_amp ) :
adv_g = hinge_g_loss ( patchgan ( recon ) )
vae_loss = vae_loss + lambda_adversarial * adv_g
# ── VAE backward ──────────────────────────────────────────────
opt_vae . zero_grad ( )
scaler . scale ( vae_loss ) . backward ( )
scaler . step ( opt_vae )
scaler . update ( )
vae_loss . backward ( )
torch . nn . utils . clip_grad_norm_ ( vae . parameters ( ) , grad_clip )
opt_vae . step ( )
ema . update ( vae )
recon_sum + = mse . item ( )
@@ -559,11 +582,11 @@ def train_vae(
adv_d_sum + = adv_d . item ( )
n_batches + = 1
avg_r = recon_sum / n_batches
avg_k = kl_sum / n_batches
avg_p = perc_sum / n_batches
avg_g = adv_g_sum / n_batches
avg_d = adv_d_sum / n_batches
avg_r = recon_sum / max ( n_batches , 1 )
avg_k = kl_sum / max ( n_batches , 1 )
avg_p = perc_sum / max ( n_batches , 1 )
avg_g = adv_g_sum / max ( n_batches , 1 )
avg_d = adv_d_sum / max ( n_batches , 1 )
history [ " recon_loss " ] . append ( avg_r )
history [ " kl_loss " ] . append ( avg_k )
history [ " perc_loss " ] . append ( avg_p )
@@ -574,6 +597,7 @@ def train_vae(
f " [ { epoch : 03d } / { epochs } ] "
f " MSE: { avg_r : .4f } KL: { avg_k : .2f } β= { current_beta : .6f } "
f " Perc: { avg_p : .4f } AdvG: { avg_g : .4f } AdvD: { avg_d : .4f } "
f " (NaN skipped: { nan_skipped } ) "
)
if epoch % sample_interval == 0 :
@@ -607,6 +631,7 @@ def train_vae(
if patchgan is not None :
torch . save ( patchgan . state_dict ( ) , save_dir / f " { run_name } _final_patchgan.pt " )
history [ " train_time_s " ] = time . time ( ) - t_start
print ( f " Total NaN-skipped batches: { nan_skipped } " )
return history