diff --git a/generator/src/training/trainer.py b/generator/src/training/trainer.py index 10ad7fb..9e501fa 100644 --- a/generator/src/training/trainer.py +++ b/generator/src/training/trainer.py @@ -276,8 +276,8 @@ def train_wgan( ema = EMA(generator, decay=ema_decay) if hasattr(torch, "compile"): + # critic excluded — GP requires double backward which torch.compile doesn't support generator = torch.compile(generator) - critic = torch.compile(critic) # Fixed noise for consistent sample tracking across epochs fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)