Files
DRL_PROJ/classifier/src/utils/cross_validation.py
T
Johnny Fernandes bb3dfb92d5 Clean state
2026-04-30 01:25:39 +01:00

86 lines
3.0 KiB
Python

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
def get_basename(path: str) -> str:
    """Return the grouping key for *path*: its filename with the last suffix removed.

    Sibling images of the same identity carry the same filename across the
    wiki/inpainting/text2img/insight sources, so keying on the stem keeps
    all of them together in a single cross-validation fold.
    """
    file_path = Path(path)
    return file_path.stem
# ── Splits ─────────────────────────────────────────────────────────────────
# Outer fold: StratifiedGroupKFold holds out one fold as test.
# Inner val: 10% of remaining groups are held out randomly — no per-class
# stratification needed since every DFF basename is multi-source (mixed label).
def create_group_kfold_splits(
samples: List[Tuple[str, int]],
n_splits: int = 5,
seed: int = 42,
) -> List[Tuple[List[int], List[int], List[int]]]:
paths = [s[0] for s in samples]
labels = np.array([s[1] for s in samples])
groups = np.array([get_basename(p) for p in paths])
sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
splits = []
for fold_idx, (train_val_idx, test_idx) in enumerate(sgkf.split(paths, labels, groups)):
train_val_groups = groups[train_val_idx]
unique_groups = np.unique(train_val_groups)
n_val_groups = max(1, int(len(unique_groups) * 0.1))
rng = np.random.RandomState(seed + fold_idx)
val_groups = set(rng.choice(unique_groups, n_val_groups, replace=False))
train_idx = []
val_idx = []
for i, g in enumerate(train_val_groups):
if g in val_groups:
val_idx.append(train_val_idx[i])
else:
train_idx.append(train_val_idx[i])
splits.append((train_idx, val_idx, test_idx.tolist()))
return splits
# ── Aggregation ────────────────────────────────────────────────────────────
def aggregate_fold_metrics(
    fold_metrics: List[Dict[str, Any]],
    metric_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Summarise per-fold metrics as mean, sample std and a 95% CI.

    Args:
        fold_metrics: One dict of metric name -> value per fold.
        metric_keys: Metrics to aggregate. If ``None``, every real-valued,
            non-bool entry of the *first* fold is used. Python ``int``/``float``
            and numpy integer/floating scalars count as numeric. Keys that
            appear only in later folds are not picked up automatically.

    Returns:
        ``{key: {"mean", "std", "ci_95", "values"}}`` for every requested key
        present in at least one fold. ``n`` counts only the folds that contain
        the key. ``std`` is the sample standard deviation (``ddof=1``), and
        ``ci_95`` is the normal-approximation half-width
        ``1.96 * std / sqrt(n)``. Both are ``0.0`` when ``n == 1``, because the
        sample std is undefined there. An empty ``fold_metrics`` yields ``{}``.
    """
    if not fold_metrics:
        return {}

    if metric_keys is None:
        metric_keys = [
            k for k, v in fold_metrics[0].items()
            if isinstance(v, (int, float, np.integer, np.floating))
            and not isinstance(v, bool)
        ]

    aggregated: Dict[str, Any] = {}
    for key in metric_keys:
        values = np.array([fold[key] for fold in fold_metrics if key in fold])
        n = len(values)
        if n == 0:
            continue
        mean = np.mean(values)
        # With a single value, ddof=1 would give NaN plus a RuntimeWarning.
        # Report 0.0 instead, matching the ci_95 convention below.
        std = np.std(values, ddof=1) if n > 1 else 0.0
        ci_95 = 1.96 * std / np.sqrt(n) if n > 1 else 0.0
        aggregated[key] = {
            "mean": float(mean),
            "std": float(std),
            "ci_95": float(ci_95),
            "values": values.tolist(),
        }
    return aggregated