Clean state
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from sklearn.model_selection import StratifiedGroupKFold
|
||||
|
||||
|
||||
# Groups by filename stem so sibling images of the same identity (same name
|
||||
# across wiki/inpainting/text2img/insight) always stay in the same fold
|
||||
def get_basename(path: str) -> str:
    """Return the last component of ``path`` with its final suffix removed.

    The result is used as the grouping key for cross-validation, so two
    paths that differ only in directory or extension map to the same key.
    """
    file_path = Path(path)
    return file_path.stem
|
||||
|
||||
|
||||
# ── Splits ─────────────────────────────────────────────────────────────────
|
||||
|
||||
# Outer fold: StratifiedGroupKFold holds out one fold as test.
|
||||
# Inner val: 10% of remaining groups are held out randomly — no per-class
|
||||
# stratification needed since every DFF basename is multi-source (mixed label).
|
||||
def create_group_kfold_splits(
    samples: List[Tuple[str, int]],
    n_splits: int = 5,
    seed: int = 42,
    val_fraction: float = 0.1,
) -> List[Tuple[List[int], List[int], List[int]]]:
    """Build group-aware train/val/test index splits, one per outer fold.

    Samples are grouped by ``get_basename(path)``. ``StratifiedGroupKFold``
    produces the outer split (one fold held out as test); a validation set
    is then carved out of the remaining samples by randomly picking whole
    groups, so a group never straddles train and val.

    Args:
        samples: ``(path, label)`` pairs; only the first two items of each
            entry are read.
        n_splits: Number of outer folds.
        seed: Seed for the outer shuffle. Fold ``i`` draws its validation
            groups from ``RandomState(seed + i)``.
        val_fraction: Share of the non-test *groups* (not samples) moved to
            validation. At least one group is always selected, so with a
            single remaining group the train split ends up empty.

    Returns:
        One ``(train_idx, val_idx, test_idx)`` tuple per fold. Every list
        holds plain Python ints indexing into ``samples``.

    Raises:
        ValueError: If ``val_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction!r}")

    paths = [s[0] for s in samples]
    labels = np.array([s[1] for s in samples])
    groups = np.array([get_basename(p) for p in paths])

    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    splits = []

    for fold_idx, (train_val_idx, test_idx) in enumerate(sgkf.split(paths, labels, groups)):
        train_val_groups = groups[train_val_idx]
        unique_groups = np.unique(train_val_groups)
        n_val_groups = max(1, int(len(unique_groups) * val_fraction))

        # Per-fold RNG keeps each fold's validation draw reproducible and
        # independent of how many folds were generated before it.
        rng = np.random.RandomState(seed + fold_idx)
        val_groups = rng.choice(unique_groups, n_val_groups, replace=False)

        # Boolean mask over the train+val samples; indexing with it keeps the
        # original order, and .tolist() converts np.int64 to Python int so all
        # three index lists have the same element type.
        in_val = np.isin(train_val_groups, val_groups)
        train_idx = train_val_idx[~in_val].tolist()
        val_idx = train_val_idx[in_val].tolist()

        splits.append((train_idx, val_idx, test_idx.tolist()))

    return splits
|
||||
|
||||
|
||||
# ── Aggregation ────────────────────────────────────────────────────────────
|
||||
|
||||
# Infers numeric keys from the first fold if metric_keys is not supplied;
|
||||
# uses sample std (ddof=1) and normal-approximation 95% CI
|
||||
def aggregate_fold_metrics(
    fold_metrics: List[Dict[str, Any]],
    metric_keys: List[str] = None,
) -> Dict[str, Any]:
    """Summarise per-fold metrics as mean, sample std and a 95% CI.

    Args:
        fold_metrics: One metrics dict per fold.
        metric_keys: Keys to aggregate. When ``None``, every key of the first
            fold whose value is a real number (Python or NumPy scalar,
            booleans excluded) is used.

    Returns:
        Mapping ``key -> {"mean", "std", "ci_95", "values"}``. ``std`` uses
        ``ddof=1``; ``ci_95`` is the half-width ``1.96 * std / sqrt(n)``.
        Both are ``0.0`` when only one fold provides the key. Keys found in
        no fold are omitted. An empty ``fold_metrics`` yields ``{}``.
    """
    if not fold_metrics:
        return {}

    if metric_keys is None:
        # np.float32 / np.int64 are not subclasses of float / int, so they are
        # listed explicitly; np.bool_ is neither np.integer nor np.floating.
        numeric_types = (int, float, np.integer, np.floating)
        metric_keys = [
            k for k, v in fold_metrics[0].items()
            if isinstance(v, numeric_types) and not isinstance(v, bool)
        ]

    aggregated = {}
    for key in metric_keys:
        values = np.array([fold[key] for fold in fold_metrics if key in fold])
        n = len(values)
        if n == 0:
            continue

        if n > 1:
            std = float(np.std(values, ddof=1))
            ci_95 = float(1.96 * std / np.sqrt(n))
        else:
            # Sample std is undefined for one observation (NumPy would warn
            # and return NaN); report zero spread, matching the CI.
            std = 0.0
            ci_95 = 0.0

        aggregated[key] = {
            "mean": float(np.mean(values)),
            "std": std,
            "ci_95": ci_95,
            "values": values.tolist(),
        }

    return aggregated
|
||||
|
||||
|
||||
Reference in New Issue
Block a user