86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
|
|
|
|
|
|
# Groups by filename stem so sibling images of the same identity (same name
# across wiki/inpainting/text2img/insight) always stay in the same fold.
|
|
def get_basename(path: str) -> str:
    """Return the final component of *path* with its last suffix removed.

    The result is used as the grouping key for the cross-validation splits.
    """
    file_path = Path(path)
    return file_path.stem
|
|
|
|
|
|
# ── Splits ─────────────────────────────────────────────────────────────────
|
|
|
|
# Outer fold: StratifiedGroupKFold holds out one fold as test.
# Inner val: 10% of the remaining groups are held out at random — no per-class
# stratification is needed since every DFF basename is multi-source (mixed label).
|
|
def create_group_kfold_splits(
    samples: List[Tuple[str, int]],
    n_splits: int = 5,
    seed: int = 42,
    val_fraction: float = 0.1,
) -> List[Tuple[List[int], List[int], List[int]]]:
    """Build grouped train/val/test index splits for cross-validation.

    Samples are grouped by ``get_basename`` of their path, so all samples that
    share a filename stem land in the same partition. The outer split uses
    ``StratifiedGroupKFold`` to hold out one fold as test; from the remaining
    groups a random subset is held out as validation.

    Args:
        samples: ``(path, label)`` pairs; only elements 0 and 1 are read.
        n_splits: Number of outer folds.
        seed: Seed for the outer shuffle; fold ``i`` draws its validation
            groups with ``seed + i``.
        val_fraction: Share of the non-test groups assigned to validation.
            At least one group is always taken, even if the share rounds
            down to zero.

    Returns:
        One ``(train_idx, val_idx, test_idx)`` tuple per fold. Every index
        list holds plain Python ints that index into ``samples``.

    Raises:
        ValueError: If ``val_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(
            f"val_fraction must be strictly between 0 and 1, got {val_fraction!r}"
        )

    paths = [s[0] for s in samples]
    labels = np.array([s[1] for s in samples])
    groups = np.array([get_basename(p) for p in paths])

    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    splits = []

    for fold_idx, (train_val_idx, test_idx) in enumerate(
        sgkf.split(paths, labels, groups)
    ):
        train_val_groups = groups[train_val_idx]
        unique_groups = np.unique(train_val_groups)
        n_val_groups = max(1, int(len(unique_groups) * val_fraction))

        # A per-fold seed keeps the validation draw reproducible yet different
        # for each fold.
        rng = np.random.RandomState(seed + fold_idx)
        val_groups = rng.choice(unique_groups, n_val_groups, replace=False)

        # Boolean mask over the train+val samples; masking preserves the
        # original index order, and .tolist() yields Python ints so all three
        # index lists have the same element type.
        in_val = np.isin(train_val_groups, val_groups)
        train_idx = train_val_idx[~in_val].tolist()
        val_idx = train_val_idx[in_val].tolist()

        splits.append((train_idx, val_idx, test_idx.tolist()))

    return splits
|
|
|
|
|
|
# ── Aggregation ────────────────────────────────────────────────────────────
|
|
|
|
# Infers numeric keys from the first fold if metric_keys is not supplied;
# uses sample std (ddof=1) and a normal-approximation 95% CI half-width.
|
|
def aggregate_fold_metrics(
    fold_metrics: List[Dict[str, Any]],
    metric_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Summarise per-fold metrics as mean, sample std and 95% CI half-width.

    Args:
        fold_metrics: One metrics dict per fold.
        metric_keys: Keys to aggregate. When ``None``, every key of the first
            fold whose value is a real number (Python or NumPy scalar, but not
            ``bool``) is used.

    Returns:
        A dict mapping each aggregated key to
        ``{"mean", "std", "ci_95", "values"}``. ``std`` uses ``ddof=1`` and
        ``ci_95`` is ``1.96 * std / sqrt(n)``; both are ``0.0`` when only one
        fold provides the key. Keys absent from every fold are omitted. An
        empty ``fold_metrics`` yields an empty dict.
    """
    if not fold_metrics:
        return {}

    if metric_keys is None:
        # NumPy scalars such as np.float32 / np.int64 do not subclass the
        # Python numeric types, so they are listed explicitly.
        numeric_types = (int, float, np.integer, np.floating)
        metric_keys = [
            k for k, v in fold_metrics[0].items()
            if isinstance(v, numeric_types) and not isinstance(v, bool)
        ]

    aggregated = {}
    for key in metric_keys:
        values = [fold[key] for fold in fold_metrics if key in fold]
        if not values:
            continue
        values = np.array(values)
        n = len(values)
        mean = np.mean(values)
        # The sample std is undefined for a single observation (NumPy returns
        # NaN with a warning), so report 0.0 in line with the ci_95 guard.
        std = np.std(values, ddof=1) if n > 1 else 0.0
        ci_95 = 1.96 * std / np.sqrt(n) if n > 1 else 0.0
        aggregated[key] = {
            "mean": float(mean),
            "std": float(std),
            "ci_95": float(ci_95),
            "values": values.tolist(),
        }

    return aggregated
|
|
|
|
|