from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold


# Groups by filename stem so sibling images of the same identity (same name
# across wiki/inpainting/text2img/insight) always stay in the same fold
def get_basename(path: str) -> str:
    """Return the filename stem of *path* (final component, last suffix removed)."""
    return Path(path).stem


# ── Splits ─────────────────────────────────────────────────────────────────
# Outer fold: StratifiedGroupKFold holds out one fold as test.
# Inner val: a fraction (10% by default) of the remaining groups is held out
# randomly — no per-class stratification needed since every DFF basename is
# multi-source (mixed label).
def create_group_kfold_splits(
    samples: List[Tuple[str, int]],
    n_splits: int = 5,
    seed: int = 42,
    val_fraction: float = 0.1,
) -> List[Tuple[List[int], List[int], List[int]]]:
    """Build group-aware train/val/test index splits for cross-validation.

    Samples are grouped by the filename stem of their path, so every sample
    sharing a stem lands in the same partition of a given fold.

    Args:
        samples: ``(path, label)`` pairs.
        n_splits: Number of outer folds; each fold's held-out part is the test set.
        seed: Seed for the outer splitter; ``seed + fold_idx`` seeds the
            per-fold validation draw.
        val_fraction: Fraction of the non-test groups moved to validation.
            At least one group is always selected, so if only a single
            non-test group exists the training list is empty.

    Returns:
        One ``(train_idx, val_idx, test_idx)`` tuple per fold. Each element is
        a list of plain Python ints indexing into ``samples``.

    Raises:
        ValueError: If ``val_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(
            f"val_fraction must be in the open interval (0, 1), got {val_fraction!r}"
        )

    paths = [path for path, _ in samples]
    labels = np.array([label for _, label in samples])
    groups = np.array([get_basename(path) for path in paths])

    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    splits: List[Tuple[List[int], List[int], List[int]]] = []
    for fold_idx, (train_val_idx, test_idx) in enumerate(
        sgkf.split(paths, labels, groups)
    ):
        train_val_groups = groups[train_val_idx]
        unique_groups = np.unique(train_val_groups)
        n_val_groups = max(1, int(len(unique_groups) * val_fraction))

        # A per-fold seed keeps each validation draw reproducible while
        # differing from fold to fold.
        rng = np.random.RandomState(seed + fold_idx)
        val_groups = rng.choice(unique_groups, n_val_groups, replace=False)

        # Boolean mask over the train+val samples; masking preserves the
        # original index order within each partition.
        in_val = np.isin(train_val_groups, val_groups)

        # .tolist() yields plain Python ints for all three lists, matching the
        # annotated return type and keeping the result JSON-serialisable.
        splits.append(
            (
                train_val_idx[~in_val].tolist(),
                train_val_idx[in_val].tolist(),
                test_idx.tolist(),
            )
        )

    return splits


# ── Aggregation ────────────────────────────────────────────────────────────
# Infers numeric keys from the first fold if metric_keys is not supplied;
# uses sample std (ddof=1) and normal-approximation 95% CI
def aggregate_fold_metrics(
    fold_metrics: List[Dict[str, Any]],
    metric_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Summarise per-fold metrics as mean, sample std and 95% CI half-width.

    Args:
        fold_metrics: One metrics dict per fold.
        metric_keys: Keys to aggregate. When ``None``, every key of the first
            fold whose value is a non-boolean int/float (Python or NumPy
            scalar) is used.

    Returns:
        A dict mapping each aggregated key to
        ``{"mean", "std", "ci_95", "values"}``. Keys absent from every fold
        are omitted. ``std`` and ``ci_95`` are ``0.0`` when only one value is
        available, because the sample standard deviation is undefined there.
        An empty ``fold_metrics`` yields an empty dict.
    """
    if not fold_metrics:
        return {}

    if metric_keys is None:
        metric_keys = [
            key
            for key, value in fold_metrics[0].items()
            if isinstance(value, (int, float, np.integer, np.floating))
            and not isinstance(value, bool)
        ]

    aggregated: Dict[str, Any] = {}
    for key in metric_keys:
        # Folds that lack the key are skipped rather than treated as errors.
        values = np.array([fold[key] for fold in fold_metrics if key in fold])
        n = len(values)
        if n == 0:
            continue

        if n > 1:
            std = float(np.std(values, ddof=1))
            ci_95 = float(1.96 * std / np.sqrt(n))
        else:
            # ddof=1 with a single value divides by zero and returns NaN.
            std = 0.0
            ci_95 = 0.0

        aggregated[key] = {
            "mean": float(np.mean(values)),
            "std": std,
            "ci_95": ci_95,
            "values": values.tolist(),
        }

    return aggregated