Files
DRL_PROJ/classifier/run.py
T
Johnny Fernandes bb3dfb92d5 Clean state
2026-04-30 01:25:39 +01:00

121 lines
4.6 KiB
Python

"""
Train a classifier with 5-fold stratified group cross-validation from a config file.
Usage:
python run.py configs/phase1/p1_simplecnn_baseline.json
python run.py configs/phase1/p1_resnet18_baseline.json --data-dir /mnt/data/DFF --output-root /mnt/results
"""
import argparse
import json
import sys
import warnings
from pathlib import Path
# PIL warns on corrupt EXIF metadata in some JPEGs — benign, not actionable.
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
def parse_args(argv=None):
    """Parse CLI arguments for a training run.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]`` when None.

    Returns:
        argparse.Namespace with config_path, data_dir, output_root, use_gpu.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("config_path", help="Path to the JSON experiment config.")
    cli.add_argument("--data-dir", default=None, help="Override cfg['data_dir'] for this run.")
    cli.add_argument("--output-root", default="classifier/outputs", help="Directory where models/logs are written. Default: classifier/outputs")
    cli.add_argument("--use-gpu", action="store_true", help="Use GPU for training.")
    return cli.parse_args(argv)
# ── Training entrypoint ─────────────────────────────────────────────────────
def main(config_path, *, data_dir_override=None, output_root="classifier/outputs", use_gpu=False):
    """Run 5-fold grouped-CV classifier training from a JSON config.

    Args:
        config_path: Path to the JSON experiment config (may use ``extends``).
        data_dir_override: If given, overrides ``cfg["data_dir"]`` for this run.
        output_root: Directory under which ``models/`` and ``logs/`` are written.
        use_gpu: Request CUDA; silently falls back to CPU (with a warning) when
            CUDA is unavailable.

    Side effects:
        Writes model checkpoints under ``<output_root>/models`` and a JSON
        results file (metrics + config snapshot) under ``<output_root>/logs``.
    """
    # Heavy imports kept local so the CLI module stays cheap to import.
    import numpy as np
    import torch
    from src.models import get_model
    from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
    from src.training import train_classifier_cv
    from src.utils import load_config

    # Load merged config (supports extends + shared defaults).
    cfg = load_config(config_path)

    # Set seeds and optional cuDNN determinism for reproducible runs.
    seed = cfg.get("seed", 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    deterministic = cfg.get("deterministic", False)
    torch.backends.cudnn.deterministic = deterministic
    # benchmark mode is non-deterministic, so it is the inverse of the flag.
    torch.backends.cudnn.benchmark = not deterministic

    run_name = cfg["run_name"]
    device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
    if use_gpu and not torch.cuda.is_available():
        print("Warning: --use-gpu specified but CUDA not available, falling back to CPU")

    # Resolve runtime paths.
    data_dir = data_dir_override or cfg.get("data_dir", "data")
    output_root = Path(output_root)
    models_dir = output_root / "models"
    logs_dir = output_root / "logs"
    # Fix: create BOTH output dirs up front. Previously only logs_dir was
    # created, so checkpoint saving could fail if save_dir did not exist.
    models_dir.mkdir(parents=True, exist_ok=True)
    logs_dir.mkdir(parents=True, exist_ok=True)

    print(f"Run: {run_name}")
    print(f"Device: {device}")
    print(f"Data dir: {data_dir}")
    print(f"Output root: {output_root}")

    # Build raw dataset once, then derive fold-specific transformed subsets.
    raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))

    # Apply deterministic subsample (if configured) before split generation.
    sampled = apply_subsample(raw_ds, cfg)
    if sampled is not None:
        n_samples, total = sampled
        print(f"Subsampled to {n_samples}/{total} samples")

    # Create grouped CV folds and a transform builder callable for train/eval.
    splits = get_splits(raw_ds, cfg)
    transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))

    print(f"\nCV Split sizes:")
    for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
        print(f"  Fold {fold_idx}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

    # Train across folds, save checkpoints, and collect aggregate metrics.
    results = train_classifier_cv(
        lambda: get_model(cfg),
        raw_ds,
        splits,
        epochs=cfg["epochs"],
        batch_size=cfg["batch_size"],
        lr=cfg["lr"],
        weight_decay=cfg.get("weight_decay", 1e-4),
        device=device,
        save_dir=models_dir,
        run_name=run_name,
        early_stopping_patience=cfg.get("early_stopping_patience", 0),
        num_workers=cfg.get("num_workers", 4),
        transform_builder=transform_builder,
        # Cosine-schedule horizon defaults to the full training length.
        T_max=cfg.get("T_max", cfg["epochs"]),
        normalization=cfg.get("normalization"),
        logs_dir=logs_dir,
    )

    # Persist metrics + config snapshot as the canonical run artifact.
    results["config"] = cfg
    out = logs_dir / f"{run_name}.json"
    with open(out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved results to {out}")
if __name__ == "__main__":
    # CLI entrypoint: translate parsed flags into keyword arguments for main().
    cli_args = parse_args(sys.argv[1:])
    main(
        cli_args.config_path,
        data_dir_override=cli_args.data_dir,
        output_root=cli_args.output_root,
        use_gpu=cli_args.use_gpu,
    )