Testing VAE until it works - v1

This commit is contained in:
Johnny Fernandes
2026-05-02 13:26:39 +01:00
parent 1bed6f0d70
commit 218123a845
2 changed files with 11 additions and 6 deletions
+3 -2
View File
@@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance
class FIDEvaluator: class FIDEvaluator:
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"): def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
num_workers: int = 2):
self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.n_real = n_real self.n_real = n_real
# Cache real images as a CPU tensor ([-1, 1] range) # Cache real images as a CPU tensor ([-1, 1] range)
imgs_list = [] imgs_list = []
loader = DataLoader(real_dataset, batch_size=256, shuffle=False, loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
num_workers=4, drop_last=False) num_workers=num_workers, drop_last=False)
for batch in loader: for batch in loader:
imgs_list.append(batch.cpu()) imgs_list.append(batch.cpu())
if sum(x.size(0) for x in imgs_list) >= n_real: if sum(x.size(0) for x in imgs_list) >= n_real:
+8 -4
View File
@@ -86,7 +86,8 @@ def train_dcgan(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}} history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")
@@ -257,7 +258,8 @@ def train_wgan(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}} history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")
@@ -497,7 +499,8 @@ def train_vae(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = { history = {
"recon_loss": [], "kl_loss": [], "perc_loss": [], "recon_loss": [], "kl_loss": [], "perc_loss": [],
@@ -697,7 +700,8 @@ def train_ddpm(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"loss": [], "fid": {}} history = {"loss": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")