Testing VAE until it works - v1
This commit is contained in:
@@ -6,5 +6,6 @@
|
|||||||
"subsample": 1.0,
|
"subsample": 1.0,
|
||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000
|
"fid_n_real": 5000,
|
||||||
|
"num_workers": 2
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ def train_dcgan(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
||||||
@@ -239,7 +239,7 @@ def train_wgan(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
||||||
@@ -444,7 +444,7 @@ def train_vae(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -680,7 +680,7 @@ def train_ddpm(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||||
|
|||||||
Reference in New Issue
Block a user