From 1bed6f0d70f60ceca9352386fdb9ef1cad464d54 Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sat, 2 May 2026 13:24:17 +0100 Subject: [PATCH] Testing VAE until it works - v1 --- generator/configs/shared.json | 3 ++- generator/src/training/trainer.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/generator/configs/shared.json b/generator/configs/shared.json index 4ba1c9b..04865b8 100644 --- a/generator/configs/shared.json +++ b/generator/configs/shared.json @@ -6,5 +6,6 @@ "subsample": 1.0, "sample_interval": 10, "fid_interval": 25, - "fid_n_real": 5000 + "fid_n_real": 5000, + "num_workers": 2 } diff --git a/generator/src/training/trainer.py b/generator/src/training/trainer.py index 688366c..75d05cb 100644 --- a/generator/src/training/trainer.py +++ b/generator/src/training/trainer.py @@ -66,7 +66,7 @@ def train_dcgan( loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, - num_workers=min(4, os.cpu_count() or 1), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)), pin_memory=(device.type == "cuda"), drop_last=True, ) opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2)) @@ -239,7 +239,7 @@ def train_wgan( loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, - num_workers=min(4, os.cpu_count() or 1), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)), pin_memory=(device.type == "cuda"), drop_last=True, ) opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2)) @@ -444,7 +444,7 @@ def train_vae( loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, - num_workers=min(4, os.cpu_count() or 1), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)), pin_memory=(device.type == "cuda"), drop_last=True, ) @@ -680,7 +680,7 @@ def train_ddpm( loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, - num_workers=min(4, os.cpu_count() or 1), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)), pin_memory=(device.type == "cuda"), drop_last=True, ) opt = torch.optim.AdamW(model.parameters(), lr=lr)