289 lines
11 KiB
Python
289 lines
11 KiB
Python
"""
|
||
Re-evaluate existing trained checkpoints with per-source metrics.
|
||
|
||
Loads each config, rebuilds CV splits (deterministic), loads the _best.pt
|
||
checkpoint per fold, runs predict_rows, and writes updated log files
|
||
with aggregate + per-source + pairwise metrics.
|
||
|
||
Usage:
|
||
python tools/reevaluate.py # re-evaluate all experiments
|
||
python tools/reevaluate.py p1_resnet18_baseline # specific experiments
|
||
python tools/reevaluate.py --data-dir /mnt/data/DFF
|
||
python tools/reevaluate.py --use-gpu
|
||
"""
|
||
import argparse
|
||
import json
|
||
import sys
|
||
import warnings
|
||
from pathlib import Path
|
||
|
||
# Suppress UserWarnings whose message starts with "Corrupt EXIF data" and
# FutureWarnings whose message mentions "weights_only".
warnings.filterwarnings("ignore", "Corrupt EXIF data", UserWarning)
warnings.filterwarnings("ignore", ".*weights_only.*", FutureWarning)

# ROOT is the directory two levels above this file (classifier/); it goes
# first on sys.path so that `src.*` imports work.
ROOT = Path(__file__).resolve().parent.parent
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))

# Relative config paths (data_dir) are resolved against the project root,
# which is the parent of ROOT.
PROJECT_ROOT = ROOT.parent
|
||
|
||
|
||
def parse_args():
    """Parse the command-line options of the re-evaluation script.

    Reads ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes
        ``run_names`` (list of str, empty when none are given),
        ``data_dir`` (str or None),
        ``output_root`` (str, default ``"outputs"``) and
        ``use_gpu`` (bool).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains a hand-formatted "Usage:" section;
        # the default formatter would re-wrap it into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "run_names",
        nargs="*",
        help="Run names to re-evaluate (matches log filenames). Default: all.",
    )
    parser.add_argument(
        "--data-dir",
        default=None,
        help="Override cfg['data_dir'] for this run.",
    )
    parser.add_argument(
        "--output-root",
        default="outputs",
        help="Directory where models/logs live. Default: outputs",
    )
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        help="Use GPU for evaluation.",
    )
    return parser.parse_args()
|
||
|
||
|
||
# Known run names, grouped by the config sub-directory that holds
# "<run_name>.json". Order here is the order of CONFIG_MAP.
_RUNS_BY_PHASE = {
    "phase1": (
        # Phase 1
        "p1_resnet18_baseline",
        "p1_simplecnn_baseline",
    ),
    "phase2": (
        # Phase 2a – shortcut / holdout
        "p2a_t1_original",
        "p2a_t2_real_norm",
        "p2a_t3_holdout_text2img",
        "p2a_t3_holdout_inpainting",
        "p2a_t3_holdout_insight",
        # Phase 2b – resolution
        "p2b_resnet18_224",
        "p2b_simplecnn_224",
        # Phase 2c – face crop
        "p2c_resnet18_facecrop",
        "p2c_simplecnn_facecrop",
        # Phase 2d – augmentation
        "p2d_resnet18_aug",
        "p2d_simplecnn_aug",
        # Phase 2e – face crop + aug
        "p2e_resnet18_facecrop_aug",
        "p2e_simplecnn_facecrop_aug",
    ),
}

# Map run_name -> config path (relative to classifier/)
CONFIG_MAP = {
    run_name: f"configs/{phase}/{run_name}.json"
    for phase, phase_runs in _RUNS_BY_PHASE.items()
    for run_name in phase_runs
}
|
||
|
||
|
||
def _format_metric(value):
    """Render a metric for console output: four decimals, or ``?`` if it cannot be formatted."""
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        # None (TypeError) or a non-numeric placeholder (ValueError).
        return "?"


def _aggregate_nested(fold_results, field, skip_keys, aggregate_fn):
    """Aggregate one nested per-fold metrics mapping across folds.

    Args:
        fold_results: list of per-fold result dicts.
        field: key in each fold result holding a ``{name: metrics_dict}`` mapping.
        skip_keys: metric keys that must not be aggregated.
        aggregate_fn: callable taking a list of flat metric dicts.

    Returns:
        ``{name: aggregate_fn([...])}`` for every name with at least one
        non-empty metrics dict, in sorted name order.
    """
    aggregated = {}
    names = sorted({name for fr in fold_results for name in fr[field]})
    for name in names:
        per_fold = []
        for fr in fold_results:
            metrics = fr[field].get(name)
            if metrics:
                # Only plain numeric entries are aggregated.
                per_fold.append({
                    k: v for k, v in metrics.items()
                    if isinstance(v, (int, float)) and k not in skip_keys
                })
        if per_fold:
            aggregated[name] = aggregate_fn(per_fold)
    return aggregated


def _carry_over_history(fold_results, existing):
    """Copy each fold's ``history`` from a previously saved log into *fold_results*."""
    history_by_fold = {}
    for old in existing.get("fold_results", []):
        # setdefault keeps the first entry per fold number.
        history_by_fold.setdefault(old.get("fold"), old.get("history"))
    for fr in fold_results:
        if fr["fold"] in history_by_fold:
            fr["history"] = history_by_fold[fr["fold"]]


def main():
    """Re-evaluate saved checkpoints and rewrite each run's JSON log.

    For every selected run this loads its config, seeds NumPy and torch,
    rebuilds the dataset and CV splits, evaluates each fold's
    ``<run_name>_fold<k>_best.pt`` checkpoint on that fold's test indices and
    writes aggregate, per-source and pairwise metrics to
    ``<output_root>/logs/<run_name>.json``. The per-fold ``history`` and the
    ``config`` of an already existing log are carried over.

    A run is skipped (with a message) when it has no ``CONFIG_MAP`` entry, its
    config file is missing, or no checkpoint is found; a single fold is
    skipped when its checkpoint file is missing.
    """
    args = parse_args()

    # Heavy and project-local imports are deferred until arguments are parsed.
    import numpy as np
    import torch
    from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
    from src.evaluation.evaluate import predict_rows
    from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics
    from src.models import get_model
    from src.utils import load_config
    from src.utils.cross_validation import aggregate_fold_metrics

    output_root = Path(args.output_root)
    logs_dir = output_root / "logs"
    models_dir = output_root / "models"

    # Determine which experiments to re-evaluate
    if args.run_names:
        run_names = args.run_names
    else:
        # Default: every existing log whose name has a known config.
        run_names = sorted(
            p.stem for p in logs_dir.glob("*.json")
            if p.stem in CONFIG_MAP
        )

    device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    for run_name in run_names:
        config_rel = CONFIG_MAP.get(run_name)
        if config_rel is None:
            print(f"\nSkipping {run_name}: no config mapping")
            continue

        config_path = ROOT / config_rel
        if not config_path.exists():
            print(f"\nSkipping {run_name}: config not found ({config_rel})")
            continue

        # Check that at least one checkpoint exists
        checkpoints = sorted(models_dir.glob(f"{run_name}_fold*_best.pt"))
        if not checkpoints:
            print(f"\nSkipping {run_name}: no checkpoints found")
            continue

        print(f"\n{'='*60}")
        print(f"Re-evaluating: {run_name}")
        print(f" Config: {config_rel}")
        print(f" Checkpoints: {len(checkpoints)}")
        print(f"{'='*60}")

        # Load config and seed all RNGs before the splits are rebuilt
        cfg = load_config(config_path)
        seed = cfg.get("seed", 42)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        deterministic = cfg.get("deterministic", False)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

        data_dir = args.data_dir or cfg.get("data_dir", "data")
        # Config paths are relative to the project root, not classifier/
        if not Path(data_dir).is_absolute():
            data_dir = str(PROJECT_ROOT / data_dir)

        # Build dataset
        raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))

        sampled = apply_subsample(raw_ds, cfg)
        if sampled is not None:
            n_samples, total = sampled
            print(f" Subsampled to {n_samples}/{total} samples")

        # Build CV splits and transforms (deterministic – same as training)
        splits = get_splits(raw_ds, cfg)
        transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))

        n_folds = len(splits)
        fold_results = []

        for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
            checkpoint_path = models_dir / f"{run_name}_fold{fold_idx}_best.pt"
            if not checkpoint_path.exists():
                print(f" Fold {fold_idx}: checkpoint missing, skipping")
                continue

            # Rebuild model and load checkpoint
            model = get_model(cfg)
            model.load_state_dict(
                torch.load(checkpoint_path, map_location=device, weights_only=True)
            )
            model.to(device).eval()

            # Build test dataset for this fold; "real_norm" statistics are
            # recomputed from this fold's training indices.
            if cfg.get("normalization") == "real_norm":
                from src.preprocessing.pipeline import compute_real_stats
                norm_mean, norm_std = compute_real_stats(raw_ds, train_idx)
            else:
                norm_mean = norm_std = None

            test_dataset = transform_builder(
                test_idx, train=False,
                normalize_mean=norm_mean, normalize_std=norm_std,
            )

            records = predict_rows(
                model, test_dataset, raw_ds, test_idx,
                cfg["batch_size"], device, num_workers=4,
            )

            # Compute metrics
            test_metrics = calc_metrics(records)
            src_metrics = source_metrics(records)
            pairwise = pair_metrics(records)

            fold_results.append({
                "fold": fold_idx,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "test_size": len(test_idx),
                "test_metrics": test_metrics,
                "source_metrics": src_metrics,
                "pair_metrics": pairwise,
            })

            # A missing or non-numeric metric prints as "?" instead of
            # raising from the ".4f" format spec.
            print(f" Fold {fold_idx}: auc={_format_metric(test_metrics.get('auc_roc'))} "
                  f"acc={_format_metric(test_metrics.get('accuracy'))} "
                  f"f1={_format_metric(test_metrics.get('f1'))}")
            for source, sm in sorted(src_metrics.items()):
                pa = sm.get("pairwise_auc")
                dr = sm.get("detection_rate")
                if pa is not None:
                    label = f"pairwise_auc={pa:.4f}"
                elif dr is not None:
                    label = f"detection_rate={dr:.4f}"
                else:
                    label = ""
                print(f" {source}: {label}")

        if not fold_results:
            print(f" No folds evaluated for {run_name}")
            continue

        # Aggregate across folds
        aggregated = aggregate_fold_metrics(
            [fr["test_metrics"] for fr in fold_results]
        )

        # Aggregate per-source and pairwise metrics
        aggregated_per_source = _aggregate_nested(
            fold_results, "source_metrics", ("fold",), aggregate_fold_metrics,
        )
        aggregated_pairwise = _aggregate_nested(
            fold_results, "pair_metrics", ("fold", "n"), aggregate_fold_metrics,
        )

        # Load existing log to preserve training history
        log_path = logs_dir / f"{run_name}.json"
        if log_path.exists():
            with open(log_path, encoding="utf-8") as f:
                existing = json.load(f)
            _carry_over_history(fold_results, existing)
        else:
            existing = {}

        results = {
            "run_name": run_name,
            "n_folds": n_folds,
            "fold_results": fold_results,
            "aggregated_metrics": aggregated,
            "aggregated_per_source": aggregated_per_source,
            "aggregated_pairwise": aggregated_pairwise,
            "config": existing.get("config", cfg),
        }

        # Serialise before touching the file: a failure here must not
        # truncate the existing log and lose its training history.
        payload = json.dumps(results, indent=2)
        logs_dir.mkdir(parents=True, exist_ok=True)
        log_path.write_text(payload, encoding="utf-8")
        print(f" Saved: {log_path}")

    print("\nDone.")
|
||
|
||
|
||
# Script entry point: run the re-evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|