"""
|
|
Train a classifier with 5-fold stratified group cross-validation from a config file.
|
|
|
|
Usage:
|
|
python run.py configs/phase1/p1_simplecnn_baseline.json
|
|
python run.py configs/phase1/p1_resnet18_baseline.json --data-dir /mnt/data/DFF --output-root /mnt/results
|
|
"""
|
|
import argparse
|
|
import json
|
|
import sys
|
|
import warnings
|
|
from pathlib import Path
|
|
|
|
# PIL warns on corrupt EXIF metadata in some JPEGs — benign, not actionable.
# Match on the message prefix so unrelated UserWarnings still surface.
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
|
|
|
|
|
|
def parse_args(argv=None):
    """Build the CLI parser and parse *argv* (defaults to ``sys.argv[1:]``)."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("config_path", help="Path to the JSON experiment config.")
    cli.add_argument(
        "--data-dir",
        default=None,
        help="Override cfg['data_dir'] for this run.",
    )
    cli.add_argument(
        "--output-root",
        default="classifier/outputs",
        help="Directory where models/logs are written. Default: classifier/outputs",
    )
    cli.add_argument("--use-gpu", action="store_true", help="Use GPU for training.")
    return cli.parse_args(argv)
|
|
|
|
|
|
# ── Training entrypoint ─────────────────────────────────────────────────────
|
|
|
|
def main(config_path, *, data_dir_override=None, output_root="classifier/outputs", use_gpu=False):
    """Train a classifier with grouped cross-validation from a JSON config.

    Args:
        config_path: Path to the JSON experiment config (``load_config``
            merges it with any ``extends`` parents / shared defaults).
        data_dir_override: If given, replaces ``cfg["data_dir"]`` for this run.
        output_root: Root directory for model checkpoints and result logs.
        use_gpu: Request CUDA; falls back to CPU with a printed warning when
            CUDA is unavailable.

    Side effects:
        Creates ``<output_root>/models`` and ``<output_root>/logs``, trains
        one model per CV fold, and writes ``<logs_dir>/<run_name>.json``
        containing the aggregate metrics plus a snapshot of the merged config.
    """
    # Heavy dependencies are imported lazily so CLI parsing/--help stays fast
    # and the module can be imported without torch installed.
    import numpy as np
    import torch

    from src.models import get_model
    from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
    from src.training import train_classifier_cv
    from src.utils import load_config

    # Load merged config (supports extends + shared defaults).
    cfg = load_config(config_path)

    # Set seeds and optional cuDNN determinism for reproducible runs.
    seed = cfg.get("seed", 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    deterministic = cfg.get("deterministic", False)
    torch.backends.cudnn.deterministic = deterministic
    # cuDNN autotuning (benchmark) is non-deterministic, so it is the inverse
    # of the determinism flag: fast kernels when reproducibility isn't required.
    torch.backends.cudnn.benchmark = not deterministic

    run_name = cfg["run_name"]
    device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
    if use_gpu and not torch.cuda.is_available():
        print("Warning: --use-gpu specified but CUDA not available, falling back to CPU")

    # Resolve runtime paths.
    data_dir = data_dir_override or cfg.get("data_dir", "data")
    output_root = Path(output_root)
    models_dir = output_root / "models"
    logs_dir = output_root / "logs"

    print(f"Run: {run_name}")
    print(f"Device: {device}")
    print(f"Data dir: {data_dir}")
    print(f"Output root: {output_root}")

    # Build raw dataset once, then derive fold-specific transformed subsets.
    raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))

    # Apply deterministic subsample (if configured) before split generation.
    sampled = apply_subsample(raw_ds, cfg)
    if sampled is not None:
        n_samples, total = sampled
        print(f"Subsampled to {n_samples}/{total} samples")

    # Create grouped CV folds and a transform builder callable for train/eval.
    splits = get_splits(raw_ds, cfg)
    transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))

    print("\nCV Split sizes:")
    for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
        print(f" Fold {fold_idx}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

    # Train across folds, save checkpoints, and collect aggregate metrics.
    # Create BOTH output dirs up front: previously only logs_dir was created,
    # so checkpoint saving into models_dir could fail late into a long run.
    models_dir.mkdir(parents=True, exist_ok=True)
    logs_dir.mkdir(parents=True, exist_ok=True)
    results = train_classifier_cv(
        lambda: get_model(cfg),
        raw_ds,
        splits,
        epochs=cfg["epochs"],
        batch_size=cfg["batch_size"],
        lr=cfg["lr"],
        weight_decay=cfg.get("weight_decay", 1e-4),
        device=device,
        save_dir=models_dir,
        run_name=run_name,
        early_stopping_patience=cfg.get("early_stopping_patience", 0),
        num_workers=cfg.get("num_workers", 4),
        transform_builder=transform_builder,
        # Cosine-annealing horizon defaults to the full training length.
        T_max=cfg.get("T_max", cfg["epochs"]),
        normalization=cfg.get("normalization"),
        logs_dir=logs_dir,
    )

    # Persist metrics + config snapshot as the canonical run artifact.
    results["config"] = cfg
    out = logs_dir / f"{run_name}.json"
    with open(out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved results to {out}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parse_args(sys.argv[1:])
|
|
main(
|
|
args.config_path,
|
|
data_dir_override=args.data_dir,
|
|
output_root=args.output_root,
|
|
use_gpu=args.use_gpu,
|
|
)
|