""" Re-evaluate existing trained checkpoints with per-source metrics. Loads each config, rebuilds CV splits (deterministic), loads the _best.pt checkpoint per fold, runs predict_rows, and writes updated log files with aggregate + per-source + pairwise metrics. Usage: python tools/reevaluate.py # re-evaluate all experiments python tools/reevaluate.py p1_resnet18_baseline # specific experiments python tools/reevaluate.py --data-dir /mnt/data/DFF python tools/reevaluate.py --use-gpu """ import argparse import json import sys import warnings from pathlib import Path warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning) warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning) # Ensure classifier/ is on sys.path so `src.*` imports work ROOT = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) # Config paths (data_dir) are relative to the project root PROJECT_ROOT = ROOT.parent def parse_args(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "run_names", nargs="*", help="Run names to re-evaluate (matches log filenames). Default: all.", ) parser.add_argument( "--data-dir", default=None, help="Override cfg['data_dir'] for this run.", ) parser.add_argument( "--output-root", default="outputs", help="Directory where models/logs live. Default: outputs", ) parser.add_argument( "--use-gpu", action="store_true", help="Use GPU for evaluation.", ) return parser.parse_args() # Map run_name -> config path (relative to classifier/) CONFIG_MAP = { # Phase 1 "p1_resnet18_baseline": "configs/phase1/p1_resnet18_baseline.json", "p1_simplecnn_baseline": "configs/phase1/p1_simplecnn_baseline.json", # Phase 2a – shortcut / holdout "p2a_t1_original": "configs/phase2/p2a_t1_original.json", "p2a_t2_real_norm": "configs/phase2/p2a_t2_real_norm.json", "p2a_t3_holdout_text2img": "configs/phase2/p2a_t3_holdout_text2img.json", "p2a_t3_holdout_inpainting": "configs/phase2/p2a_t3_holdout_inpainting.json", "p2a_t3_holdout_insight": "configs/phase2/p2a_t3_holdout_insight.json", # Phase 2b – resolution "p2b_resnet18_224": "configs/phase2/p2b_resnet18_224.json", "p2b_simplecnn_224": "configs/phase2/p2b_simplecnn_224.json", # Phase 2c – face crop "p2c_resnet18_facecrop": "configs/phase2/p2c_resnet18_facecrop.json", "p2c_simplecnn_facecrop": "configs/phase2/p2c_simplecnn_facecrop.json", # Phase 2d – augmentation "p2d_resnet18_aug": "configs/phase2/p2d_resnet18_aug.json", "p2d_simplecnn_aug": "configs/phase2/p2d_simplecnn_aug.json", # Phase 2e – face crop + aug "p2e_resnet18_facecrop_aug": "configs/phase2/p2e_resnet18_facecrop_aug.json", "p2e_simplecnn_facecrop_aug": "configs/phase2/p2e_simplecnn_facecrop_aug.json", } def main(): args = parse_args() import numpy as np import torch from src.data import DFFDataset, apply_subsample, build_transforms, get_splits from src.evaluation.evaluate import predict_rows from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics from src.models import get_model from src.utils import load_config from src.utils.cross_validation import aggregate_fold_metrics output_root = Path(args.output_root) logs_dir = output_root / "logs" models_dir = output_root / "models" # Determine which experiments to re-evaluate if args.run_names: run_names = args.run_names else: run_names = sorted( p.stem for p in logs_dir.glob("*.json") if p.stem in CONFIG_MAP ) device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu" print(f"Device: {device}") for 
        config_rel = CONFIG_MAP.get(run_name)
        if config_rel is None:
            print(f"\nSkipping {run_name}: no config mapping")
            continue
        config_path = ROOT / config_rel
        if not config_path.exists():
            print(f"\nSkipping {run_name}: config not found ({config_rel})")
            continue

        # Check that at least one checkpoint exists
        checkpoints = sorted(models_dir.glob(f"{run_name}_fold*_best.pt"))
        if not checkpoints:
            print(f"\nSkipping {run_name}: no checkpoints found")
            continue

        print(f"\n{'='*60}")
        print(f"Re-evaluating: {run_name}")
        print(f"  Config: {config_rel}")
        print(f"  Checkpoints: {len(checkpoints)}")
        print(f"{'='*60}")

        # Load config and re-seed so the CV splits match training
        cfg = load_config(config_path)
        seed = cfg.get("seed", 42)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        deterministic = cfg.get("deterministic", False)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

        data_dir = args.data_dir or cfg.get("data_dir", "data")
        # Config paths are relative to the project root, not classifier/
        if not Path(data_dir).is_absolute():
            data_dir = str(PROJECT_ROOT / data_dir)

        # Build dataset
        raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))
        sampled = apply_subsample(raw_ds, cfg)
        if sampled is not None:
            n_samples, total = sampled
            print(f"  Subsampled to {n_samples}/{total} samples")

        # Build CV splits and transforms (deterministic – same as training)
        splits = get_splits(raw_ds, cfg)
        transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))
        n_folds = len(splits)

        fold_results = []
        for fold_idx in range(n_folds):
            train_idx, val_idx, test_idx = splits[fold_idx]
            checkpoint_path = models_dir / f"{run_name}_fold{fold_idx}_best.pt"
            if not checkpoint_path.exists():
                print(f"  Fold {fold_idx}: checkpoint missing, skipping")
                continue

            # Rebuild model and load checkpoint
            model = get_model(cfg)
            model.load_state_dict(
                torch.load(checkpoint_path, map_location=device, weights_only=True)
            )
            model.to(device).eval()

            # Build test dataset for this fold
            if cfg.get("normalization") == "real_norm":
                from src.preprocessing.pipeline import compute_real_stats

                norm_mean, norm_std = compute_real_stats(raw_ds, train_idx)
            else:
                norm_mean = norm_std = None

            test_dataset = transform_builder(
                test_idx,
                train=False,
                normalize_mean=norm_mean,
                normalize_std=norm_std,
            )

            records = predict_rows(
                model,
                test_dataset,
                raw_ds,
                test_idx,
                cfg["batch_size"],
                device,
                num_workers=4,
            )

            # Compute metrics
            test_metrics = calc_metrics(records)
            src_metrics = source_metrics(records)
            pairwise = pair_metrics(records)

            fold_result = {
                "fold": fold_idx,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "test_size": len(test_idx),
                "test_metrics": test_metrics,
                "source_metrics": src_metrics,
                "pair_metrics": pairwise,
            }
            fold_results.append(fold_result)

            # NaN fallback keeps the :.4f format spec valid if a metric is missing
            print(f"  Fold {fold_idx}: auc={test_metrics.get('auc_roc', float('nan')):.4f} "
                  f"acc={test_metrics.get('accuracy', float('nan')):.4f} "
                  f"f1={test_metrics.get('f1', float('nan')):.4f}")
            for source, sm in sorted(src_metrics.items()):
                pa = sm.get("pairwise_auc")
                dr = sm.get("detection_rate")
                label = (f"pairwise_auc={pa:.4f}" if pa is not None
                         else f"detection_rate={dr:.4f}" if dr is not None
                         else "")
                print(f"    {source}: {label}")

        if not fold_results:
            print(f"  No folds evaluated for {run_name}")
            continue

        # Aggregate across folds
        test_metrics_list = [f["test_metrics"] for f in fold_results]
        aggregated = aggregate_fold_metrics(test_metrics_list)

        # Aggregate per-source metrics
        all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
f["source_metrics"]}) aggregated_per_source = {} for source in all_sources: source_fold_metrics = [] for f in fold_results: sm = f["source_metrics"].get(source) if sm: source_fold_metrics.append({ k: v for k, v in sm.items() if isinstance(v, (int, float)) and k != "fold" }) if source_fold_metrics: aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics) # Aggregate pairwise metrics all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]}) aggregated_pairwise = {} for pair in all_pairs: pair_fold_metrics = [] for f in fold_results: pm = f["pair_metrics"].get(pair) if pm: pair_fold_metrics.append({ k: v for k, v in pm.items() if isinstance(v, (int, float)) and k not in ("fold", "n") }) if pair_fold_metrics: aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics) # Load existing log to preserve training history log_path = logs_dir / f"{run_name}.json" if log_path.exists(): with open(log_path) as f: existing = json.load(f) # Keep the training history from the original log for fr_new in fold_results: for fr_old in existing.get("fold_results", []): if fr_old["fold"] == fr_new["fold"]: fr_new["history"] = fr_old.get("history") break else: existing = {} results = { "run_name": run_name, "n_folds": n_folds, "fold_results": fold_results, "aggregated_metrics": aggregated, "aggregated_per_source": aggregated_per_source, "aggregated_pairwise": aggregated_pairwise, "config": existing.get("config", cfg), } with open(log_path, "w") as f: json.dump(results, f, indent=2) print(f" Saved: {log_path}") print("\nDone.") if __name__ == "__main__": main()