Testing VAE until it works - v1
This commit is contained in:
@@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
|
||||
|
||||
class FIDEvaluator:
|
||||
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
|
||||
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
|
||||
num_workers: int = 2):
|
||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
self.n_real = n_real
|
||||
|
||||
# Cache real images as a CPU tensor ([-1, 1] range)
|
||||
imgs_list = []
|
||||
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
|
||||
num_workers=4, drop_last=False)
|
||||
num_workers=num_workers, drop_last=False)
|
||||
for batch in loader:
|
||||
imgs_list.append(batch.cpu())
|
||||
if sum(x.size(0) for x in imgs_list) >= n_real:
|
||||
|
||||
@@ -86,7 +86,8 @@ def train_dcgan(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
@@ -257,7 +258,8 @@ def train_wgan(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
@@ -497,7 +499,8 @@ def train_vae(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {
|
||||
"recon_loss": [], "kl_loss": [], "perc_loss": [],
|
||||
@@ -697,7 +700,8 @@ def train_ddpm(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {"loss": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
|
||||
Reference in New Issue
Block a user