""" Train a classifier with 5-fold stratified group cross-validation from a config file. Usage: python run.py configs/phase1/p1_simplecnn_baseline.json python run.py configs/phase1/p1_resnet18_baseline.json --data-dir /mnt/data/DFF --output-root /mnt/results """ import argparse import json import sys import warnings from pathlib import Path # PIL warns on corrupt EXIF metadata in some JPEGs — benign, not actionable. warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning) def parse_args(argv=None): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("config_path", help="Path to the JSON experiment config.") parser.add_argument("--data-dir", default=None, help="Override cfg['data_dir'] for this run.") parser.add_argument("--output-root", default="classifier/outputs", help="Directory where models/logs are written. Default: classifier/outputs") parser.add_argument("--use-gpu", action="store_true", help="Use GPU for training.") return parser.parse_args(argv) # ── Training entrypoint ───────────────────────────────────────────────────── def main(config_path, *, data_dir_override=None, output_root="classifier/outputs", use_gpu=False): import numpy as np import torch from src.models import get_model from src.data import DFFDataset, apply_subsample, build_transforms, get_splits from src.training import train_classifier_cv from src.utils import load_config # Load merged config (supports extends + shared defaults). cfg = load_config(config_path) # Set seeds and optional cuDNN determinism for reproducible runs. seed = cfg.get("seed", 42) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) deterministic = cfg.get("deterministic", False) torch.backends.cudnn.deterministic = deterministic torch.backends.cudnn.benchmark = not deterministic run_name = cfg["run_name"] device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" if use_gpu and not torch.cuda.is_available(): print("Warning: --use-gpu specified but CUDA not available, falling back to CPU") # Resolve runtime paths. data_dir = data_dir_override or cfg.get("data_dir", "data") output_root = Path(output_root) models_dir = output_root / "models" logs_dir = output_root / "logs" print(f"Run: {run_name}") print(f"Device: {device}") print(f"Data dir: {data_dir}") print(f"Output root: {output_root}") # Build raw dataset once, then derive fold-specific transformed subsets. raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources")) # Apply deterministic subsample (if configured) before split generation. sampled = apply_subsample(raw_ds, cfg) if sampled is not None: n_samples, total = sampled print(f"Subsampled to {n_samples}/{total} samples") # Create grouped CV folds and a transform builder callable for train/eval. splits = get_splits(raw_ds, cfg) transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment")) print(f"\nCV Split sizes:") for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits): print(f" Fold {fold_idx}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") # Train across folds, save checkpoints, and collect aggregate metrics. logs_dir.mkdir(parents=True, exist_ok=True) results = train_classifier_cv( lambda: get_model(cfg), raw_ds, splits, epochs=cfg["epochs"], batch_size=cfg["batch_size"], lr=cfg["lr"], weight_decay=cfg.get("weight_decay", 1e-4), device=device, save_dir=models_dir, run_name=run_name, early_stopping_patience=cfg.get("early_stopping_patience", 0), num_workers=cfg.get("num_workers", 4), transform_builder=transform_builder, T_max=cfg.get("T_max", cfg["epochs"]), normalization=cfg.get("normalization"), logs_dir=logs_dir, ) # Persist metrics + config snapshot as the canonical run artifact. results["config"] = cfg out = logs_dir / f"{run_name}.json" with open(out, "w") as f: json.dump(results, f, indent=2) print(f"\nSaved results to {out}") if __name__ == "__main__": args = parse_args(sys.argv[1:]) main( args.config_path, data_dir_override=args.data_dir, output_root=args.output_root, use_gpu=args.use_gpu, )