Clean state
This commit is contained in:
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Train a classifier with 5-fold stratified group cross-validation from a config file.
|
||||
|
||||
Usage:
|
||||
python run.py configs/phase1/p1_simplecnn_baseline.json
|
||||
python run.py configs/phase1/p1_resnet18_baseline.json --data-dir /mnt/data/DFF --output-root /mnt/results
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# PIL warns on corrupt EXIF metadata in some JPEGs — benign, not actionable.
|
||||
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line arguments for a training run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config_path``, ``data_dir``,
        ``output_root``, and ``use_gpu`` attributes.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("config_path", help="Path to the JSON experiment config.")
    cli.add_argument(
        "--data-dir",
        default=None,
        help="Override cfg['data_dir'] for this run.",
    )
    cli.add_argument(
        "--output-root",
        default="classifier/outputs",
        help="Directory where models/logs are written. Default: classifier/outputs",
    )
    cli.add_argument("--use-gpu", action="store_true", help="Use GPU for training.")
    return cli.parse_args(argv)
|
||||
|
||||
|
||||
# ── Training entrypoint ─────────────────────────────────────────────────────
|
||||
|
||||
def main(config_path, *, data_dir_override=None, output_root="classifier/outputs", use_gpu=False):
    """Train a classifier with grouped cross-validation from a JSON config.

    Args:
        config_path: Path to the JSON experiment config (supports ``extends``
            via ``load_config`` — presumably merges shared defaults; see src.utils).
        data_dir_override: If given, overrides ``cfg["data_dir"]`` for this run.
        output_root: Root directory under which ``models/`` and ``logs/`` are
            created for checkpoints and run artifacts.
        use_gpu: Train on CUDA when available; falls back to CPU with a warning.

    Side effects:
        Creates ``<output_root>/models`` and ``<output_root>/logs``, writes
        per-fold checkpoints (via the trainer) and a
        ``<logs>/<run_name>.json`` results file containing metrics plus a
        snapshot of the config.
    """
    # Heavy dependencies are imported lazily so importing this module
    # (e.g. for --help) stays cheap.
    import numpy as np
    import torch

    from src.models import get_model
    from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
    from src.training import train_classifier_cv
    from src.utils import load_config

    # Load merged config (supports extends + shared defaults).
    cfg = load_config(config_path)

    # Set seeds and optional cuDNN determinism for reproducible runs.
    seed = cfg.get("seed", 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    deterministic = cfg.get("deterministic", False)
    torch.backends.cudnn.deterministic = deterministic
    # cuDNN autotuning (benchmark) is non-deterministic, so the two flags
    # are toggled together.
    torch.backends.cudnn.benchmark = not deterministic

    run_name = cfg["run_name"]
    device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
    if use_gpu and not torch.cuda.is_available():
        print("Warning: --use-gpu specified but CUDA not available, falling back to CPU")

    # Resolve runtime paths.
    data_dir = data_dir_override or cfg.get("data_dir", "data")
    output_root = Path(output_root)
    models_dir = output_root / "models"
    logs_dir = output_root / "logs"
    # Create BOTH output directories up front: previously only logs_dir was
    # created, so checkpoint writes under models_dir could fail on a missing
    # parent directory.
    models_dir.mkdir(parents=True, exist_ok=True)
    logs_dir.mkdir(parents=True, exist_ok=True)

    print(f"Run: {run_name}")
    print(f"Device: {device}")
    print(f"Data dir: {data_dir}")
    print(f"Output root: {output_root}")

    # Build raw dataset once, then derive fold-specific transformed subsets.
    raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))

    # Apply deterministic subsample (if configured) before split generation.
    sampled = apply_subsample(raw_ds, cfg)
    if sampled is not None:
        n_samples, total = sampled
        print(f"Subsampled to {n_samples}/{total} samples")

    # Create grouped CV folds and a transform builder callable for train/eval.
    splits = get_splits(raw_ds, cfg)
    transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))

    print("\nCV Split sizes:")
    for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
        print(f"  Fold {fold_idx}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

    # Train across folds, save checkpoints, and collect aggregate metrics.
    results = train_classifier_cv(
        # Factory so the trainer gets a fresh model per fold.
        lambda: get_model(cfg),
        raw_ds,
        splits,
        epochs=cfg["epochs"],
        batch_size=cfg["batch_size"],
        lr=cfg["lr"],
        weight_decay=cfg.get("weight_decay", 1e-4),
        device=device,
        save_dir=models_dir,
        run_name=run_name,
        early_stopping_patience=cfg.get("early_stopping_patience", 0),
        num_workers=cfg.get("num_workers", 4),
        transform_builder=transform_builder,
        # Cosine-annealing horizon defaults to the full epoch count.
        T_max=cfg.get("T_max", cfg["epochs"]),
        normalization=cfg.get("normalization"),
        logs_dir=logs_dir,
    )

    # Persist metrics + config snapshot as the canonical run artifact.
    results["config"] = cfg
    out = logs_dir / f"{run_name}.json"
    with open(out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved results to {out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # parse_args() defaults to sys.argv[1:] when given no argument list.
    cli = parse_args()
    main(
        cli.config_path,
        data_dir_override=cli.data_dir,
        output_root=cli.output_root,
        use_gpu=cli.use_gpu,
    )
|
||||
Reference in New Issue
Block a user