Testing VAE until it works - v1
This commit is contained in:
@@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance
|
|||||||
|
|
||||||
|
|
||||||
class FIDEvaluator:
|
class FIDEvaluator:
|
||||||
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
|
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
|
||||||
|
num_workers: int = 2):
|
||||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
self.n_real = n_real
|
self.n_real = n_real
|
||||||
|
|
||||||
# Cache real images as a CPU tensor ([-1, 1] range)
|
# Cache real images as a CPU tensor ([-1, 1] range)
|
||||||
imgs_list = []
|
imgs_list = []
|
||||||
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
|
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
|
||||||
num_workers=4, drop_last=False)
|
num_workers=num_workers, drop_last=False)
|
||||||
for batch in loader:
|
for batch in loader:
|
||||||
imgs_list.append(batch.cpu())
|
imgs_list.append(batch.cpu())
|
||||||
if sum(x.size(0) for x in imgs_list) >= n_real:
|
if sum(x.size(0) for x in imgs_list) >= n_real:
|
||||||
|
|||||||
@@ -86,7 +86,8 @@ def train_dcgan(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
@@ -257,7 +258,8 @@ def train_wgan(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
@@ -497,7 +499,8 @@ def train_vae(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {
|
history = {
|
||||||
"recon_loss": [], "kl_loss": [], "perc_loss": [],
|
"recon_loss": [], "kl_loss": [], "perc_loss": [],
|
||||||
@@ -697,7 +700,8 @@ def train_ddpm(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {"loss": [], "fid": {}}
|
history = {"loss": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
|
|||||||
Reference in New Issue
Block a user