Clean state
This commit is contained in:
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Train a generative model from a config file.

Usage:
    python run.py <config.json>
    python run.py <config.json> --data-dir /path/to/data --output-root generator/outputs
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# Put this file's directory on sys.path so the script can be launched either
# from the project root (python3 generator/run.py ...) or from inside generator/.
_here = Path(__file__).resolve().parent
if sys.path.count(str(_here)) == 0:
    sys.path.insert(0, str(_here))

# Suppress UserWarnings whose message starts with "Corrupt EXIF data".
warnings.filterwarnings(
    "ignore",
    message="Corrupt EXIF data",
    category=UserWarning,
)
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse the command-line arguments of the training script.

    Args:
        argv: Sequence of argument strings to parse. It is forwarded
            unchanged to ``ArgumentParser.parse_args``.

    Returns:
        The ``argparse.Namespace`` with ``config_path``, ``data_dir``,
        ``output_root`` and ``use_gpu`` attributes.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Required positional: the config file to train from.
    cli.add_argument("config_path")

    # Optional overrides.
    cli.add_argument("--data-dir", default=None)
    cli.add_argument("--output-root", default="generator/outputs")

    # The flag is parsed but its value is not consulted by main().
    cli.add_argument(
        "--use-gpu",
        action="store_true",
        help="Accepted for pipeline compatibility (GPU auto-detected).",
    )

    return cli.parse_args(argv)
|
||||
|
||||
|
||||
def main(config_path, *, data_dir_override=None, output_root="generator/outputs"):
    """Train the model described by a config file and write a JSON training log.

    Args:
        config_path: Path of the config file, passed to ``load_config``.
        data_dir_override: When truthy, used as the data directory instead of
            the config's ``"data_dir"`` entry (which itself defaults to
            ``"data"``).
        output_root: Root output directory. Models are saved under
            ``<output_root>/models`` and the log is written to
            ``<output_root>/logs/<run_name>.json``.

    Raises:
        KeyError: If the loaded config has no ``"run_name"`` entry.
        NotImplementedError: If ``get_model`` reports a kind other than
            ``"dcgan"``.
    """
    # Local imports: the ``src`` package is resolved through the sys.path
    # entry this module adds at import time.
    import torch

    from src.data import GeneratorDataset, get_transform
    from src.models import get_model
    from src.training import train_dcgan
    from src.utils import load_config

    cfg = load_config(config_path)

    run_name = cfg["run_name"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = data_dir_override or cfg.get("data_dir", "data")
    output_root = Path(output_root)
    models_dir = output_root / "models"
    logs_dir = output_root / "logs"

    print(f"Run: {run_name}")
    print(f"Config: {cfg}")
    print(f"Device: {device} Data: {data_dir}")

    model, kind = get_model(cfg)

    # Fail fast: reject unsupported kinds before building the transform and
    # dataset, instead of after that work has already been done.
    if kind != "dcgan":
        raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")

    augment = cfg.get("augment", True)
    transform = get_transform(cfg.get("image_size", 128), augment=augment)
    dataset = GeneratorDataset(
        data_dir,
        sources=cfg.get("sources", ["wiki"]),
        subsample=cfg.get("subsample", 1.0),
        transform=transform,
    )
    print(f"Dataset size: {len(dataset)}")

    # For the "dcgan" kind the model is unpacked as a (generator, discriminator) pair.
    generator, discriminator = model
    history = train_dcgan(
        generator, discriminator, dataset, cfg,
        save_dir=models_dir, run_name=run_name, device=device,
    )

    logs_dir.mkdir(parents=True, exist_ok=True)
    out = logs_dir / f"{run_name}.json"
    # Explicit encoding so the log file does not depend on the platform locale.
    with open(out, "w", encoding="utf-8") as f:
        json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2)
    print(f"\nSaved log to {out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line and start the training run.
    cli_args = parse_args(sys.argv[1:])
    main(
        cli_args.config_path,
        data_dir_override=cli_args.data_dir,
        output_root=cli_args.output_root,
    )
|
||||
Reference in New Issue
Block a user