DRL_PROJ/classifier/tools/reevaluate.py
"""
Re-evaluate existing trained checkpoints with per-source metrics.
Loads each config, rebuilds CV splits (deterministic), loads the _best.pt
checkpoint per fold, runs predict_rows, and writes updated log files
with aggregate + per-source + pairwise metrics.
Usage:
python tools/reevaluate.py # re-evaluate all experiments
python tools/reevaluate.py p1_resnet18_baseline # specific experiments
python tools/reevaluate.py --data-dir /mnt/data/DFF
python tools/reevaluate.py --use-gpu
"""
import argparse
import json
import sys
import warnings
from pathlib import Path

warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)

# Ensure classifier/ is on sys.path so `src.*` imports work
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Config paths (data_dir) are relative to the project root
PROJECT_ROOT = ROOT.parent


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "run_names", nargs="*",
        help="Run names to re-evaluate (matches log filenames). Default: all.",
    )
    parser.add_argument(
        "--data-dir", default=None,
        help="Override cfg['data_dir'] for this run.",
    )
    parser.add_argument(
        "--output-root", default="outputs",
        help="Directory where models/logs live. Default: outputs",
    )
    parser.add_argument(
        "--use-gpu", action="store_true",
        help="Use GPU for evaluation.",
    )
    return parser.parse_args()


# Map run_name -> config path (relative to classifier/)
CONFIG_MAP = {
    # Phase 1
    "p1_resnet18_baseline": "configs/phase1/p1_resnet18_baseline.json",
    "p1_simplecnn_baseline": "configs/phase1/p1_simplecnn_baseline.json",
    # Phase 2a shortcut / holdout
    "p2a_t1_original": "configs/phase2/p2a_t1_original.json",
    "p2a_t2_real_norm": "configs/phase2/p2a_t2_real_norm.json",
    "p2a_t3_holdout_text2img": "configs/phase2/p2a_t3_holdout_text2img.json",
    "p2a_t3_holdout_inpainting": "configs/phase2/p2a_t3_holdout_inpainting.json",
    "p2a_t3_holdout_insight": "configs/phase2/p2a_t3_holdout_insight.json",
    # Phase 2b resolution
    "p2b_resnet18_224": "configs/phase2/p2b_resnet18_224.json",
    "p2b_simplecnn_224": "configs/phase2/p2b_simplecnn_224.json",
    # Phase 2c face crop
    "p2c_resnet18_facecrop": "configs/phase2/p2c_resnet18_facecrop.json",
    "p2c_simplecnn_facecrop": "configs/phase2/p2c_simplecnn_facecrop.json",
    # Phase 2d augmentation
    "p2d_resnet18_aug": "configs/phase2/p2d_resnet18_aug.json",
    "p2d_simplecnn_aug": "configs/phase2/p2d_simplecnn_aug.json",
    # Phase 2e face crop + aug
    "p2e_resnet18_facecrop_aug": "configs/phase2/p2e_resnet18_facecrop_aug.json",
    "p2e_simplecnn_facecrop_aug": "configs/phase2/p2e_simplecnn_facecrop_aug.json",
}
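
# A new experiment only needs an entry in CONFIG_MAP; log and checkpoint
# discovery below keys everything off the run name.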


def main():
    args = parse_args()

    # Heavy imports are deferred until after argument parsing
    import numpy as np
    import torch

    from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
    from src.evaluation.evaluate import predict_rows
    from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics
    from src.models import get_model
    from src.utils import load_config
    from src.utils.cross_validation import aggregate_fold_metrics

    output_root = Path(args.output_root)
    logs_dir = output_root / "logs"
    models_dir = output_root / "models"
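
    # Expected layout under --output-root (matches the glob patterns below):
    #   <output_root>/logs/<run_name>.json               the log, rewritten in place
    #   <output_root>/models/<run_name>_fold<i>_best.pt  one checkpoint per CV fold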

    # Determine which experiments to re-evaluate
    if args.run_names:
        run_names = args.run_names
    else:
        run_names = sorted(
            p.stem for p in logs_dir.glob("*.json")
            if p.stem in CONFIG_MAP
        )

    device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    for run_name in run_names:
        config_rel = CONFIG_MAP.get(run_name)
        if config_rel is None:
            print(f"\nSkipping {run_name}: no config mapping")
            continue
        config_path = ROOT / config_rel
        if not config_path.exists():
            print(f"\nSkipping {run_name}: config not found ({config_rel})")
            continue

        # Check that at least one checkpoint exists
        checkpoints = sorted(models_dir.glob(f"{run_name}_fold*_best.pt"))
        if not checkpoints:
            print(f"\nSkipping {run_name}: no checkpoints found")
            continue

        print(f"\n{'='*60}")
        print(f"Re-evaluating: {run_name}")
        print(f" Config: {config_rel}")
        print(f" Checkpoints: {len(checkpoints)}")
        print(f"{'='*60}")

        # Load config
        cfg = load_config(config_path)
        seed = cfg.get("seed", 42)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        deterministic = cfg.get("deterministic", False)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic
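
        # Re-seeding mirrors the training script; the intent (assumed here)
        # is that get_splits() below draws from these generators, or from
        # cfg["seed"] directly, so the CV folds match training exactly.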

        data_dir = args.data_dir or cfg.get("data_dir", "data")
        # Config paths are relative to the project root, not classifier/
        if not Path(data_dir).is_absolute():
            data_dir = str(PROJECT_ROOT / data_dir)

        # Build dataset
        raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))
        sampled = apply_subsample(raw_ds, cfg)
        if sampled is not None:
            n_samples, total = sampled
            print(f" Subsampled to {n_samples}/{total} samples")

        # Build CV splits and transforms (deterministic, same as training)
        splits = get_splits(raw_ds, cfg)
        transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))

        n_folds = len(splits)
        fold_results = []

        for fold_idx in range(n_folds):
            train_idx, val_idx, test_idx = splits[fold_idx]
            checkpoint_path = models_dir / f"{run_name}_fold{fold_idx}_best.pt"
            if not checkpoint_path.exists():
                print(f" Fold {fold_idx}: checkpoint missing, skipping")
                continue

            # Rebuild model and load checkpoint
            model = get_model(cfg)
            model.load_state_dict(
                torch.load(checkpoint_path, map_location=device, weights_only=True)
            )
            model.to(device).eval()

            # Build test dataset for this fold
            if cfg.get("normalization") == "real_norm":
                from src.preprocessing.pipeline import compute_real_stats
                norm_mean, norm_std = compute_real_stats(raw_ds, train_idx)
            else:
                norm_mean = norm_std = None
            test_dataset = transform_builder(
                test_idx, train=False,
                normalize_mean=norm_mean, normalize_std=norm_std,
            )

            records = predict_rows(
                model, test_dataset, raw_ds, test_idx,
                cfg["batch_size"], device, num_workers=4,
            )
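
            # `records` is consumed by all three metric helpers below;
            # assumed shape: one entry per test image carrying prediction,
            # label, and generator-source fields.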

            # Compute metrics
            test_metrics = calc_metrics(records)
            src_metrics = source_metrics(records)
            pairwise = pair_metrics(records)

            fold_result = {
                "fold": fold_idx,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "test_size": len(test_idx),
                "test_metrics": test_metrics,
                "source_metrics": src_metrics,
                "pair_metrics": pairwise,
            }
            fold_results.append(fold_result)

            # Missing metrics default to NaN so the :.4f formatting cannot raise
            nan = float("nan")
            print(f" Fold {fold_idx}: auc={test_metrics.get('auc_roc', nan):.4f} "
                  f"acc={test_metrics.get('accuracy', nan):.4f} "
                  f"f1={test_metrics.get('f1', nan):.4f}")
            for source, sm in sorted(src_metrics.items()):
                pa = sm.get("pairwise_auc")
                dr = sm.get("detection_rate")
                label = (f"pairwise_auc={pa:.4f}" if pa is not None
                         else f"detection_rate={dr:.4f}" if dr is not None else "")
                print(f" {source}: {label}")

        if not fold_results:
            print(f" No folds evaluated for {run_name}")
            continue

        # Aggregate across folds
        test_metrics_list = [f["test_metrics"] for f in fold_results]
        aggregated = aggregate_fold_metrics(test_metrics_list)

        # Aggregate per-source metrics
        all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
        aggregated_per_source = {}
        for source in all_sources:
            source_fold_metrics = []
            for f in fold_results:
                sm = f["source_metrics"].get(source)
                if sm:
                    source_fold_metrics.append({
                        k: v for k, v in sm.items()
                        if isinstance(v, (int, float)) and k != "fold"
                    })
            if source_fold_metrics:
                aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)

        # Aggregate pairwise metrics
        all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
        aggregated_pairwise = {}
        for pair in all_pairs:
            pair_fold_metrics = []
            for f in fold_results:
                pm = f["pair_metrics"].get(pair)
                if pm:
                    pair_fold_metrics.append({
                        k: v for k, v in pm.items()
                        if isinstance(v, (int, float)) and k not in ("fold", "n")
                    })
            if pair_fold_metrics:
                aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)
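
        # aggregate_fold_metrics presumably reduces each numeric metric to a
        # cross-fold summary (e.g. mean/std); non-numeric keys are filtered
        # out above so the helper only ever sees numbers.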

        # Load existing log to preserve training history
        log_path = logs_dir / f"{run_name}.json"
        if log_path.exists():
            with open(log_path) as f:
                existing = json.load(f)
            # Keep the training history from the original log
            for fr_new in fold_results:
                for fr_old in existing.get("fold_results", []):
                    if fr_old["fold"] == fr_new["fold"]:
                        fr_new["history"] = fr_old.get("history")
                        break
        else:
            existing = {}

        results = {
            "run_name": run_name,
            "n_folds": n_folds,
            "fold_results": fold_results,
            "aggregated_metrics": aggregated,
            "aggregated_per_source": aggregated_per_source,
            "aggregated_pairwise": aggregated_pairwise,
            "config": existing.get("config", cfg),
        }
        with open(log_path, "w") as f:
            json.dump(results, f, indent=2)
        print(f" Saved: {log_path}")

    print("\nDone.")


if __name__ == "__main__":
    main()
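
# Sketch of inspecting a regenerated log afterwards (assumes the default
# --output-root and one of the run names from CONFIG_MAP):
#   import json
#   with open("outputs/logs/p1_resnet18_baseline.json") as f:
#       log = json.load(f)
#   print(log["aggregated_metrics"])
#   print(sorted(log["aggregated_per_source"]))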