Preview of phase 2-5 implementation; needs a full check
This commit is contained in:
+27
-2
@@ -32,7 +32,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
|
||||
import torch
|
||||
from src.data import GeneratorDataset, get_transform
|
||||
from src.models import get_model
|
||||
from src.training import train_dcgan
|
||||
from src.training import train_dcgan, train_wgan, train_vae, train_ddpm
|
||||
from src.utils import load_config
|
||||
|
||||
cfg = load_config(config_path)
|
||||
@@ -50,6 +50,13 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
|
||||
|
||||
model, kind = get_model(cfg)
|
||||
|
||||
# Count total trainable parameters
|
||||
if isinstance(model, tuple):
|
||||
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad)
|
||||
else:
|
||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(f"Trainable params: {n_params:,}")
|
||||
|
||||
augment = cfg.get("augment", True)
|
||||
transform = get_transform(cfg.get("image_size", 128), augment=augment)
|
||||
dataset = GeneratorDataset(
|
||||
@@ -66,13 +73,31 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
|
||||
generator, discriminator, dataset, cfg,
|
||||
save_dir=models_dir, run_name=run_name, device=device,
|
||||
)
|
||||
elif kind == "wgan":
|
||||
generator, critic = model
|
||||
history = train_wgan(
|
||||
generator, critic, dataset, cfg,
|
||||
save_dir=models_dir, run_name=run_name, device=device,
|
||||
)
|
||||
elif kind == "vae":
|
||||
history = train_vae(
|
||||
model, dataset, cfg,
|
||||
save_dir=models_dir, run_name=run_name, device=device,
|
||||
)
|
||||
elif kind == "ddpm":
|
||||
history = train_ddpm(
|
||||
model, dataset, cfg,
|
||||
save_dir=models_dir, run_name=run_name, device=device,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")
|
||||
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
out = logs_dir / f"{run_name}.json"
|
||||
log_data = {"run_name": run_name, "config": cfg, "history": history}
|
||||
log_data["n_params"] = n_params
|
||||
with open(out, "w") as f:
|
||||
json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2)
|
||||
json.dump(log_data, f, indent=2)
|
||||
print(f"\nSaved log to {out}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user