Testing VAE until it works - v1
This commit is contained in:
+1
-1
@@ -52,7 +52,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
|
||||
|
||||
# Count total trainable parameters
|
||||
if isinstance(model, tuple):
|
||||
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad)
|
||||
n_params = sum(p.numel() for m in model for p in m.parameters() if p.requires_grad)
|
||||
else:
|
||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(f"Trainable params: {n_params:,}")
|
||||
|
||||
Reference in New Issue
Block a user