From 218123a845c1be054c4a70f0cf9075ec534eb4aa Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sat, 2 May 2026 13:26:39 +0100 Subject: [PATCH] Testing VAE until it works - v1 --- generator/src/training/fid.py | 5 +++-- generator/src/training/trainer.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/generator/src/training/fid.py b/generator/src/training/fid.py index 2d0c842..206f0b7 100644 --- a/generator/src/training/fid.py +++ b/generator/src/training/fid.py @@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance class FIDEvaluator: - def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"): + def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda", + num_workers: int = 2): self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.n_real = n_real # Cache real images as a CPU tensor ([-1, 1] range) imgs_list = [] loader = DataLoader(real_dataset, batch_size=256, shuffle=False, - num_workers=4, drop_last=False) + num_workers=num_workers, drop_last=False) for batch in loader: imgs_list.append(batch.cpu()) if sum(x.size(0) for x in imgs_list) >= n_real: diff --git a/generator/src/training/trainer.py b/generator/src/training/trainer.py index 75d05cb..ce08a6d 100644 --- a/generator/src/training/trainer.py +++ b/generator/src/training/trainer.py @@ -86,7 +86,8 @@ def train_dcgan( save_dir.mkdir(parents=True, exist_ok=True) samples_dir = save_dir.parent / "samples" / run_name - fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) + fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1))) history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}} best_fid = float("inf") @@ -257,7 +258,8 @@ def train_wgan( save_dir.mkdir(parents=True, exist_ok=True) samples_dir = save_dir.parent / "samples" / run_name - fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) + fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1))) history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}} best_fid = float("inf") @@ -497,7 +499,8 @@ def train_vae( save_dir.mkdir(parents=True, exist_ok=True) samples_dir = save_dir.parent / "samples" / run_name - fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) + fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1))) history = { "recon_loss": [], "kl_loss": [], "perc_loss": [], @@ -697,7 +700,8 @@ def train_ddpm( save_dir.mkdir(parents=True, exist_ok=True) samples_dir = save_dir.parent / "samples" / run_name - fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) + fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device), + num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1))) history = {"loss": [], "fid": {}} best_fid = float("inf")