Testing VAE until it works - v1
This commit is contained in:
@@ -86,7 +86,8 @@ def train_dcgan(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
@@ -257,7 +258,8 @@ def train_wgan(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
@@ -497,7 +499,8 @@ def train_vae(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {
|
||||
"recon_loss": [], "kl_loss": [], "perc_loss": [],
|
||||
@@ -697,7 +700,8 @@ def train_ddpm(
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||
|
||||
history = {"loss": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
|
||||
Reference in New Issue
Block a user