Testing VAE until it works - v1
This commit is contained in:
@@ -22,17 +22,21 @@ def _init_weights(m):
|
|||||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.zeros_(m.bias)
|
nn.init.zeros_(m.bias)
|
||||||
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
|
||||||
nn.init.normal_(m.weight, 1.0, 0.02)
|
nn.init.normal_(m.weight, 1.0, 0.02)
|
||||||
nn.init.zeros_(m.bias)
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
|
||||||
|
def _norm(channels: int) -> nn.GroupNorm:
|
||||||
|
return nn.GroupNorm(8, channels)
|
||||||
|
|
||||||
|
|
||||||
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
||||||
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
|
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Upsample(scale_factor=2, mode="nearest"),
|
nn.Upsample(scale_factor=2, mode="nearest"),
|
||||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(out_ch),
|
_norm(out_ch),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +69,7 @@ class VAE(nn.Module):
|
|||||||
for _ in range(n_down - 1):
|
for _ in range(n_down - 1):
|
||||||
enc_layers += [
|
enc_layers += [
|
||||||
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
|
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(ch * 2),
|
_norm(ch * 2),
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
]
|
]
|
||||||
ch *= 2
|
ch *= 2
|
||||||
@@ -102,6 +106,8 @@ class VAE(nn.Module):
|
|||||||
return self.fc_mu(h), log_var
|
return self.fc_mu(h), log_var
|
||||||
|
|
||||||
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
|
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.training:
|
||||||
|
return mu
|
||||||
std = torch.exp(0.5 * log_var)
|
std = torch.exp(0.5 * log_var)
|
||||||
return mu + std * torch.randn_like(std)
|
return mu + std * torch.randn_like(std)
|
||||||
|
|
||||||
|
|||||||
@@ -20,3 +20,5 @@ class EMA:
|
|||||||
def update(self, model: nn.Module) -> None:
|
def update(self, model: nn.Module) -> None:
|
||||||
for p_ema, p in zip(self.model.parameters(), model.parameters()):
|
for p_ema, p in zip(self.model.parameters(), model.parameters()):
|
||||||
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
|
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
|
||||||
|
for b_ema, b in zip(self.model.buffers(), model.buffers()):
|
||||||
|
b_ema.copy_(b)
|
||||||
|
|||||||
Reference in New Issue
Block a user