Testing VAE until it works - v1

This commit is contained in:
Johnny Fernandes
2026-05-02 14:06:31 +01:00
parent 218123a845
commit 5ccb106edb
2 changed files with 11 additions and 3 deletions
+9 -3
View File
@@ -22,17 +22,21 @@ def _init_weights(m):
nn.init.normal_(m.weight, 0.0, 0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
def _norm(channels: int) -> nn.GroupNorm:
return nn.GroupNorm(8, channels)
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
_norm(out_ch),
nn.ReLU(inplace=True),
)
@@ -65,7 +69,7 @@ class VAE(nn.Module):
for _ in range(n_down - 1):
enc_layers += [
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ch * 2),
_norm(ch * 2),
nn.LeakyReLU(0.2, inplace=True),
]
ch *= 2
@@ -102,6 +106,8 @@ class VAE(nn.Module):
return self.fc_mu(h), log_var
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
if not self.training:
return mu
std = torch.exp(0.5 * log_var)
return mu + std * torch.randn_like(std)
+2
View File
@@ -20,3 +20,5 @@ class EMA:
def update(self, model: nn.Module) -> None:
for p_ema, p in zip(self.model.parameters(), model.parameters()):
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
for b_ema, b in zip(self.model.buffers(), model.buffers()):
b_ema.copy_(b)