From 5ccb106edbb8d43d186a75e4a0e5b52afd6b13e3 Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sat, 2 May 2026 14:06:31 +0100 Subject: [PATCH] Testing VAE until it works - v1 --- generator/src/models/vae.py | 12 +++++++++--- generator/src/training/ema.py | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/generator/src/models/vae.py b/generator/src/models/vae.py index 55ffa86..126a08b 100644 --- a/generator/src/models/vae.py +++ b/generator/src/models/vae.py @@ -22,17 +22,21 @@ def _init_weights(m): nn.init.normal_(m.weight, 0.0, 0.02) if m.bias is not None: nn.init.zeros_(m.bias) - elif isinstance(m, nn.BatchNorm2d) and m.weight is not None: + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None: nn.init.normal_(m.weight, 1.0, 0.02) nn.init.zeros_(m.bias) +def _norm(channels: int) -> nn.GroupNorm: + return nn.GroupNorm(8, channels) + + def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential: """Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard.""" return nn.Sequential( nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), - nn.BatchNorm2d(out_ch), + _norm(out_ch), nn.ReLU(inplace=True), ) @@ -65,7 +69,7 @@ class VAE(nn.Module): for _ in range(n_down - 1): enc_layers += [ nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False), - nn.BatchNorm2d(ch * 2), + _norm(ch * 2), nn.LeakyReLU(0.2, inplace=True), ] ch *= 2 @@ -102,6 +106,8 @@ class VAE(nn.Module): return self.fc_mu(h), log_var def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: + if not self.training: + return mu std = torch.exp(0.5 * log_var) return mu + std * torch.randn_like(std) diff --git a/generator/src/training/ema.py b/generator/src/training/ema.py index 4782787..97f5bd8 100644 --- a/generator/src/training/ema.py +++ b/generator/src/training/ema.py @@ -20,3 +20,5 @@ class EMA: def update(self, model: nn.Module) -> None: for p_ema, p in zip(self.model.parameters(), model.parameters()): p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay) + for b_ema, b in zip(self.model.buffers(), model.buffers()): + b_ema.copy_(b)