Clean state
@@ -0,0 +1,288 @@
"""
Re-evaluate existing trained checkpoints with per-source metrics.

Loads each config, rebuilds CV splits (deterministic), loads the _best.pt
checkpoint per fold, runs predict_rows, and writes updated log files
with aggregate + per-source + pairwise metrics.

Usage:
    python tools/reevaluate.py                       # re-evaluate all experiments
    python tools/reevaluate.py p1_resnet18_baseline  # specific experiments
    python tools/reevaluate.py --data-dir /mnt/data/DFF
    python tools/reevaluate.py --use-gpu
"""
import argparse
import json
import sys
import warnings
from pathlib import Path

warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)

# Ensure classifier/ is on sys.path so `src.*` imports work
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Config paths (data_dir) are relative to the project root
PROJECT_ROOT = ROOT.parent


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "run_names", nargs="*",
        help="Run names to re-evaluate (matches log filenames). Default: all.",
    )
    parser.add_argument(
        "--data-dir", default=None,
        help="Override cfg['data_dir'] for this run.",
    )
    parser.add_argument(
        "--output-root", default="outputs",
        help="Directory where models/logs live. Default: outputs",
    )
    parser.add_argument(
        "--use-gpu", action="store_true",
        help="Use GPU for evaluation.",
    )
    return parser.parse_args()


# Map run_name -> config path (relative to classifier/)
CONFIG_MAP = {
    # Phase 1
    "p1_resnet18_baseline": "configs/phase1/p1_resnet18_baseline.json",
    "p1_simplecnn_baseline": "configs/phase1/p1_simplecnn_baseline.json",
    # Phase 2a – shortcut / holdout
    "p2a_t1_original": "configs/phase2/p2a_t1_original.json",
    "p2a_t2_real_norm": "configs/phase2/p2a_t2_real_norm.json",
    "p2a_t3_holdout_text2img": "configs/phase2/p2a_t3_holdout_text2img.json",
    "p2a_t3_holdout_inpainting": "configs/phase2/p2a_t3_holdout_inpainting.json",
    "p2a_t3_holdout_insight": "configs/phase2/p2a_t3_holdout_insight.json",
    # Phase 2b – resolution
    "p2b_resnet18_224": "configs/phase2/p2b_resnet18_224.json",
    "p2b_simplecnn_224": "configs/phase2/p2b_simplecnn_224.json",
    # Phase 2c – face crop
    "p2c_resnet18_facecrop": "configs/phase2/p2c_resnet18_facecrop.json",
    "p2c_simplecnn_facecrop": "configs/phase2/p2c_simplecnn_facecrop.json",
    # Phase 2d – augmentation
    "p2d_resnet18_aug": "configs/phase2/p2d_resnet18_aug.json",
    "p2d_simplecnn_aug": "configs/phase2/p2d_simplecnn_aug.json",
    # Phase 2e – face crop + aug
    "p2e_resnet18_facecrop_aug": "configs/phase2/p2e_resnet18_facecrop_aug.json",
    "p2e_simplecnn_facecrop_aug": "configs/phase2/p2e_simplecnn_facecrop_aug.json",
}


def main():
    args = parse_args()
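    # Imports that pull in numpy, torch, and the project modules are deferred
    # until after argument parsing.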
    import numpy as np
    import torch
    from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
    from src.evaluation.evaluate import predict_rows
    from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics
    from src.models import get_model
    from src.utils import load_config
    from src.utils.cross_validation import aggregate_fold_metrics

    output_root = Path(args.output_root)
    logs_dir = output_root / "logs"
    models_dir = output_root / "models"

    # Determine which experiments to re-evaluate
    if args.run_names:
        run_names = args.run_names
    else:
        run_names = sorted(
            p.stem for p in logs_dir.glob("*.json")
            if p.stem in CONFIG_MAP
        )

    device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    for run_name in run_names:
        config_rel = CONFIG_MAP.get(run_name)
        if config_rel is None:
            print(f"\nSkipping {run_name}: no config mapping")
            continue

        config_path = ROOT / config_rel
        if not config_path.exists():
            print(f"\nSkipping {run_name}: config not found ({config_rel})")
            continue

        # Check that at least one checkpoint exists
        checkpoints = sorted(models_dir.glob(f"{run_name}_fold*_best.pt"))
        if not checkpoints:
            print(f"\nSkipping {run_name}: no checkpoints found")
            continue

        print(f"\n{'='*60}")
        print(f"Re-evaluating: {run_name}")
        print(f" Config: {config_rel}")
        print(f" Checkpoints: {len(checkpoints)}")
        print(f"{'='*60}")

        # Load config
        cfg = load_config(config_path)
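        # Seed every RNG from the config, mirroring the training setup so the
        # CV splits rebuilt below stay deterministic.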
        seed = cfg.get("seed", 42)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        deterministic = cfg.get("deterministic", False)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

        data_dir = args.data_dir or cfg.get("data_dir", "data")
        # Config paths are relative to the project root, not classifier/
        if not Path(data_dir).is_absolute():
            data_dir = str(PROJECT_ROOT / data_dir)

        # Build dataset
        raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))

        sampled = apply_subsample(raw_ds, cfg)
        if sampled is not None:
            n_samples, total = sampled
            print(f" Subsampled to {n_samples}/{total} samples")

        # Build CV splits and transforms (deterministic – same as training)
        splits = get_splits(raw_ds, cfg)
        transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))

        n_folds = len(splits)
        fold_results = []

        for fold_idx in range(n_folds):
            train_idx, val_idx, test_idx = splits[fold_idx]

            checkpoint_path = models_dir / f"{run_name}_fold{fold_idx}_best.pt"
            if not checkpoint_path.exists():
                print(f" Fold {fold_idx}: checkpoint missing, skipping")
                continue

            # Rebuild model and load checkpoint
            model = get_model(cfg)
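            # weights_only=True restricts torch.load to unpickling tensors and
            # primitive containers from the checkpoint file.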
            model.load_state_dict(
                torch.load(checkpoint_path, map_location=device, weights_only=True)
            )
            model.to(device).eval()

            # Build test dataset for this fold
            if cfg.get("normalization") == "real_norm":
                from src.preprocessing.pipeline import compute_real_stats
                norm_mean, norm_std = compute_real_stats(raw_ds, train_idx)
            else:
                norm_mean = norm_std = None

            test_dataset = transform_builder(
                test_idx, train=False,
                normalize_mean=norm_mean, normalize_std=norm_std,
            )

            records = predict_rows(
                model, test_dataset, raw_ds, test_idx,
                cfg["batch_size"], device, num_workers=4,
            )

            # Compute metrics
            test_metrics = calc_metrics(records)
            src_metrics = source_metrics(records)
            pairwise = pair_metrics(records)

            fold_result = {
                "fold": fold_idx,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "test_size": len(test_idx),
                "test_metrics": test_metrics,
                "source_metrics": src_metrics,
                "pair_metrics": pairwise,
            }
            fold_results.append(fold_result)

            # Default to NaN rather than a string so the :.4f format specifier
            # cannot raise if a metric key is ever missing.
            print(f" Fold {fold_idx}: auc={test_metrics.get('auc_roc', float('nan')):.4f} "
                  f"acc={test_metrics.get('accuracy', float('nan')):.4f} "
                  f"f1={test_metrics.get('f1', float('nan')):.4f}")
            for source, sm in sorted(src_metrics.items()):
                pa = sm.get("pairwise_auc")
                dr = sm.get("detection_rate")
                label = (f"pairwise_auc={pa:.4f}" if pa is not None
                         else f"detection_rate={dr:.4f}" if dr is not None else "")
                print(f" {source}: {label}")

        if not fold_results:
            print(f" No folds evaluated for {run_name}")
            continue

        # Aggregate across folds
        test_metrics_list = [f["test_metrics"] for f in fold_results]
        aggregated = aggregate_fold_metrics(test_metrics_list)

        # Aggregate per-source metrics
        all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
        aggregated_per_source = {}
        for source in all_sources:
            source_fold_metrics = []
            for f in fold_results:
                sm = f["source_metrics"].get(source)
                if sm:
                    source_fold_metrics.append({
                        k: v for k, v in sm.items()
                        if isinstance(v, (int, float)) and k != "fold"
                    })
            if source_fold_metrics:
                aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)

        # Aggregate pairwise metrics
        all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
        aggregated_pairwise = {}
        for pair in all_pairs:
            pair_fold_metrics = []
            for f in fold_results:
                pm = f["pair_metrics"].get(pair)
                if pm:
                    pair_fold_metrics.append({
                        k: v for k, v in pm.items()
                        if isinstance(v, (int, float)) and k not in ("fold", "n")
                    })
            if pair_fold_metrics:
                aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)

        # Load existing log to preserve training history
        log_path = logs_dir / f"{run_name}.json"
        if log_path.exists():
            with open(log_path) as f:
                existing = json.load(f)
            # Keep the training history from the original log
            for fr_new in fold_results:
                for fr_old in existing.get("fold_results", []):
                    if fr_old["fold"] == fr_new["fold"]:
                        fr_new["history"] = fr_old.get("history")
                        break
        else:
            existing = {}

        results = {
            "run_name": run_name,
            "n_folds": n_folds,
            "fold_results": fold_results,
            "aggregated_metrics": aggregated,
            "aggregated_per_source": aggregated_per_source,
            "aggregated_pairwise": aggregated_pairwise,
            "config": existing.get("config", cfg),
        }

        with open(log_path, "w") as f:
            json.dump(results, f, indent=2)
        print(f" Saved: {log_path}")

    print("\nDone.")


if __name__ == "__main__":
    main()