Testing VAE until it works - v1

This commit is contained in:
Johnny Fernandes
2026-05-02 13:24:17 +01:00
parent ec8d4ae336
commit 1bed6f0d70
2 changed files with 6 additions and 5 deletions
+4 -4
View File
@@ -66,7 +66,7 @@ def train_dcgan(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -239,7 +239,7 @@ def train_wgan(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -444,7 +444,7 @@ def train_vae(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
@@ -680,7 +680,7 @@ def train_ddpm(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt = torch.optim.AdamW(model.parameters(), lr=lr)