Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+288
View File
@@ -0,0 +1,288 @@
"""
Re-evaluate existing trained checkpoints with per-source metrics.
Loads each config, rebuilds CV splits (deterministic), loads the _best.pt
checkpoint per fold, runs predict_rows, and writes updated log files
with aggregate + per-source + pairwise metrics.
Usage:
python tools/reevaluate.py # re-evaluate all experiments
python tools/reevaluate.py p1_resnet18_baseline # specific experiments
python tools/reevaluate.py --data-dir /mnt/data/DFF
python tools/reevaluate.py --use-gpu
"""
import argparse
import json
import sys
import warnings
from pathlib import Path
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)
# Ensure classifier/ is on sys.path so `src.*` imports work
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
# Config paths (data_dir) are relative to the project root
PROJECT_ROOT = ROOT.parent
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"run_names", nargs="*",
help="Run names to re-evaluate (matches log filenames). Default: all.",
)
parser.add_argument(
"--data-dir", default=None,
help="Override cfg['data_dir'] for this run.",
)
parser.add_argument(
"--output-root", default="outputs",
help="Directory where models/logs live. Default: outputs",
)
parser.add_argument(
"--use-gpu", action="store_true",
help="Use GPU for evaluation.",
)
return parser.parse_args()
# Map run_name -> config path (relative to classifier/)
CONFIG_MAP = {
# Phase 1
"p1_resnet18_baseline": "configs/phase1/p1_resnet18_baseline.json",
"p1_simplecnn_baseline": "configs/phase1/p1_simplecnn_baseline.json",
# Phase 2a shortcut / holdout
"p2a_t1_original": "configs/phase2/p2a_t1_original.json",
"p2a_t2_real_norm": "configs/phase2/p2a_t2_real_norm.json",
"p2a_t3_holdout_text2img": "configs/phase2/p2a_t3_holdout_text2img.json",
"p2a_t3_holdout_inpainting": "configs/phase2/p2a_t3_holdout_inpainting.json",
"p2a_t3_holdout_insight": "configs/phase2/p2a_t3_holdout_insight.json",
# Phase 2b resolution
"p2b_resnet18_224": "configs/phase2/p2b_resnet18_224.json",
"p2b_simplecnn_224": "configs/phase2/p2b_simplecnn_224.json",
# Phase 2c face crop
"p2c_resnet18_facecrop": "configs/phase2/p2c_resnet18_facecrop.json",
"p2c_simplecnn_facecrop": "configs/phase2/p2c_simplecnn_facecrop.json",
# Phase 2d augmentation
"p2d_resnet18_aug": "configs/phase2/p2d_resnet18_aug.json",
"p2d_simplecnn_aug": "configs/phase2/p2d_simplecnn_aug.json",
# Phase 2e face crop + aug
"p2e_resnet18_facecrop_aug": "configs/phase2/p2e_resnet18_facecrop_aug.json",
"p2e_simplecnn_facecrop_aug": "configs/phase2/p2e_simplecnn_facecrop_aug.json",
}
def main():
args = parse_args()
import numpy as np
import torch
from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
from src.evaluation.evaluate import predict_rows
from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics
from src.models import get_model
from src.utils import load_config
from src.utils.cross_validation import aggregate_fold_metrics
output_root = Path(args.output_root)
logs_dir = output_root / "logs"
models_dir = output_root / "models"
# Determine which experiments to re-evaluate
if args.run_names:
run_names = args.run_names
else:
run_names = sorted(
p.stem for p in logs_dir.glob("*.json")
if p.stem in CONFIG_MAP
)
device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
for run_name in run_names:
config_rel = CONFIG_MAP.get(run_name)
if config_rel is None:
print(f"\nSkipping {run_name}: no config mapping")
continue
config_path = ROOT / config_rel
if not config_path.exists():
print(f"\nSkipping {run_name}: config not found ({config_rel})")
continue
# Check that at least one checkpoint exists
checkpoints = sorted(models_dir.glob(f"{run_name}_fold*_best.pt"))
if not checkpoints:
print(f"\nSkipping {run_name}: no checkpoints found")
continue
print(f"\n{'='*60}")
print(f"Re-evaluating: {run_name}")
print(f" Config: {config_rel}")
print(f" Checkpoints: {len(checkpoints)}")
print(f"{'='*60}")
# Load config
cfg = load_config(config_path)
seed = cfg.get("seed", 42)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
deterministic = cfg.get("deterministic", False)
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
data_dir = args.data_dir or cfg.get("data_dir", "data")
# Config paths are relative to the project root, not classifier/
if not Path(data_dir).is_absolute():
data_dir = str(PROJECT_ROOT / data_dir)
# Build dataset
raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))
sampled = apply_subsample(raw_ds, cfg)
if sampled is not None:
n_samples, total = sampled
print(f" Subsampled to {n_samples}/{total} samples")
# Build CV splits and transforms (deterministic same as training)
splits = get_splits(raw_ds, cfg)
transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))
n_folds = len(splits)
fold_results = []
for fold_idx in range(n_folds):
train_idx, val_idx, test_idx = splits[fold_idx]
checkpoint_path = models_dir / f"{run_name}_fold{fold_idx}_best.pt"
if not checkpoint_path.exists():
print(f" Fold {fold_idx}: checkpoint missing, skipping")
continue
# Rebuild model and load checkpoint
model = get_model(cfg)
model.load_state_dict(
torch.load(checkpoint_path, map_location=device, weights_only=True)
)
model.to(device).eval()
# Build test dataset for this fold
if cfg.get("normalization") == "real_norm":
from src.preprocessing.pipeline import compute_real_stats
norm_mean, norm_std = compute_real_stats(raw_ds, train_idx)
else:
norm_mean = norm_std = None
test_dataset = transform_builder(
test_idx, train=False,
normalize_mean=norm_mean, normalize_std=norm_std,
)
records = predict_rows(
model, test_dataset, raw_ds, test_idx,
cfg["batch_size"], device, num_workers=4,
)
# Compute metrics
test_metrics = calc_metrics(records)
src_metrics = source_metrics(records)
pairwise = pair_metrics(records)
fold_result = {
"fold": fold_idx,
"train_size": len(train_idx),
"val_size": len(val_idx),
"test_size": len(test_idx),
"test_metrics": test_metrics,
"source_metrics": src_metrics,
"pair_metrics": pairwise,
}
fold_results.append(fold_result)
print(f" Fold {fold_idx}: auc={test_metrics.get('auc_roc', '?'):.4f} "
f"acc={test_metrics.get('accuracy', '?'):.4f} "
f"f1={test_metrics.get('f1', '?'):.4f}")
for source, sm in sorted(src_metrics.items()):
pa = sm.get("pairwise_auc")
dr = sm.get("detection_rate")
label = (f"pairwise_auc={pa:.4f}" if pa is not None
else f"detection_rate={dr:.4f}" if dr is not None else "")
print(f" {source}: {label}")
if not fold_results:
print(f" No folds evaluated for {run_name}")
continue
# Aggregate across folds
test_metrics_list = [f["test_metrics"] for f in fold_results]
aggregated = aggregate_fold_metrics(test_metrics_list)
# Aggregate per-source metrics
all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
aggregated_per_source = {}
for source in all_sources:
source_fold_metrics = []
for f in fold_results:
sm = f["source_metrics"].get(source)
if sm:
source_fold_metrics.append({
k: v for k, v in sm.items()
if isinstance(v, (int, float)) and k != "fold"
})
if source_fold_metrics:
aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)
# Aggregate pairwise metrics
all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
aggregated_pairwise = {}
for pair in all_pairs:
pair_fold_metrics = []
for f in fold_results:
pm = f["pair_metrics"].get(pair)
if pm:
pair_fold_metrics.append({
k: v for k, v in pm.items()
if isinstance(v, (int, float)) and k not in ("fold", "n")
})
if pair_fold_metrics:
aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)
# Load existing log to preserve training history
log_path = logs_dir / f"{run_name}.json"
if log_path.exists():
with open(log_path) as f:
existing = json.load(f)
# Keep the training history from the original log
for fr_new in fold_results:
for fr_old in existing.get("fold_results", []):
if fr_old["fold"] == fr_new["fold"]:
fr_new["history"] = fr_old.get("history")
break
else:
existing = {}
results = {
"run_name": run_name,
"n_folds": n_folds,
"fold_results": fold_results,
"aggregated_metrics": aggregated,
"aggregated_per_source": aggregated_per_source,
"aggregated_pairwise": aggregated_pairwise,
"config": existing.get("config", cfg),
}
with open(log_path, "w") as f:
json.dump(results, f, indent=2)
print(f" Saved: {log_path}")
print("\nDone.")
if __name__ == "__main__":
main()