Clean state

Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
src/training/__init__.py (+3)
@@ -0,0 +1,3 @@
from src.training.trainer import train_classifier, train_classifier_cv
__all__ = ["train_classifier", "train_classifier_cv"]
src/training/trainer.py (+374)
@@ -0,0 +1,374 @@
from collections import Counter
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from src.evaluation.metrics import binary_metrics
# ── AMP compatibility shim ─────────────────────────────────────────────────
# torch.amp.GradScaler / autocast moved from torch.cuda.amp in PyTorch 2.3+
if hasattr(torch.amp, "GradScaler"):
_GradScaler = torch.amp.GradScaler
_autocast = torch.amp.autocast
else:
    from torch.cuda.amp import GradScaler as _OldGradScaler, autocast as _OldAutocast

    def _GradScaler(device="", enabled=True, **kw):  # accept and ignore the device arg
        return _OldGradScaler(enabled=enabled, **kw)

    def _autocast(device_type="", enabled=True, **kw):  # accept and ignore device_type
        return _OldAutocast(enabled=enabled, **kw)
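# Usage sketch (not executed here): with the shim, one call site works on both
# PyTorch lines, e.g.
#   scaler = _GradScaler("cuda", enabled=True)
#   with _autocast("cuda", enabled=True):
#       logits = model(images)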
# ── Single-fold training ───────────────────────────────────────────────────
# Trains one fold; saves best checkpoint by val AUC-ROC and final checkpoint.
# pos_weight is passed through to BCEWithLogitsLoss to handle class imbalance.
def train_classifier(
model,
train_dataset,
val_dataset,
*,
epochs=10,
batch_size=16,
lr=1e-4,
weight_decay=1e-4,
device="cuda",
save_dir="outputs/models",
run_name="classifier",
early_stopping_patience=0,
num_workers=4,
grad_clip_norm=1.0,
T_max=None,
pos_weight=None,
):
device = torch.device(device)
if device.type == "cuda":
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
print("Using CPU")
use_amp = device.type == "cuda"
model = model.to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=True
)
val_loader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
)
pw = torch.tensor([pos_weight], device=device) if pos_weight is not None else None
criterion = nn.BCEWithLogitsLoss(pos_weight=pw)
optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=lr, weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max or epochs)
scaler = _GradScaler("cuda", enabled=use_amp)
print(f"Device: {device} AMP: {use_amp}")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
history = {
"train_loss": [], "train_acc": [], "train_auc": [], "train_f1": [],
"val_loss": [], "val_acc": [], "val_auc": [], "val_f1": [],
}
best_auc = 0.0
patience_counter = 0
for epoch in range(1, epochs + 1):
# ── train ──
model.train()
total_loss = 0.0
train_logits, train_labels = [], []
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [train]", leave=False):
images = images.to(device)
labels = labels.float().to(device)
optimizer.zero_grad()
with _autocast("cuda", enabled=use_amp):
logits = model(images).squeeze(1)
loss = criterion(logits, labels)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
scaler.step(optimizer)
scaler.update()
total_loss += loss.item() * len(images)
train_logits.append(logits.detach().cpu())
train_labels.append(labels.detach().cpu())
train_loss = total_loss / sum(len(logit) for logit in train_logits)
train_m = binary_metrics(torch.cat(train_logits), torch.cat(train_labels))
scheduler.step()
# ── validate ──
model.eval()
val_loss = 0.0
all_logits, all_labels = [], []
with torch.no_grad():
for images, labels in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [val]", leave=False):
images = images.to(device)
labels = labels.float().to(device)
with _autocast("cuda", enabled=use_amp):
logits = model(images).squeeze(1)
batch_loss = criterion(logits, labels)
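                # guard: skip non-finite losses (e.g. AMP overflow) from the running total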
if not (torch.isnan(batch_loss) or torch.isinf(batch_loss)):
val_loss += batch_loss.item() * len(images)
all_logits.append(logits.cpu())
all_labels.append(labels.cpu())
val_loss /= len(val_dataset)
val_m = binary_metrics(torch.cat(all_logits), torch.cat(all_labels))
# ── record ──
history["train_loss"].append(train_loss)
history["train_acc"].append(train_m["accuracy"])
history["train_auc"].append(train_m["auc_roc"])
history["train_f1"].append(train_m["f1"])
history["val_loss"].append(val_loss)
history["val_acc"].append(val_m["accuracy"])
history["val_auc"].append(val_m["auc_roc"])
history["val_f1"].append(val_m["f1"])
gap_loss = train_loss - val_loss
gap_acc = train_m["accuracy"] - val_m["accuracy"]
print(
f"[{epoch:03d}/{epochs}] "
f"loss: {train_loss:.4f}/{val_loss:.4f} (gap {gap_loss:+.4f}) "
f"acc: {train_m['accuracy']:.4f}/{val_m['accuracy']:.4f} (gap {gap_acc:+.4f}) "
f"auc: {train_m['auc_roc']:.4f}/{val_m['auc_roc']:.4f} "
f"f1: {train_m['f1']:.4f}/{val_m['f1']:.4f}"
)
# ── checkpoint ──
if val_m["auc_roc"] is not None and val_m["auc_roc"] > best_auc:
best_auc = val_m["auc_roc"]
torch.save(model.state_dict(), save_dir / f"{run_name}_best.pt")
patience_counter = 0
else:
patience_counter += 1
if early_stopping_patience > 0 and patience_counter >= early_stopping_patience:
print(f"Early stopping at epoch {epoch} (no improvement for {early_stopping_patience} epochs)")
break
torch.save(model.state_dict(), save_dir / f"{run_name}_final.pt")
return history
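# Usage sketch for a single fold (hypothetical `model`, `train_ds`, `val_ds`;
# pos_weight here is an assumed value, not computed by this function):
#
#   history = train_classifier(
#       model, train_ds, val_ds,
#       epochs=20, batch_size=32, lr=3e-4,
#       run_name="classifier_fold0", pos_weight=2.3,
#   )
#   best_val_auc = max(a for a in history["val_auc"] if a is not None)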
# ── CV training ────────────────────────────────────────────────────────────
# Iterates over pre-built splits, trains one model per fold, evaluates on the
# held-out test fold, then aggregates metrics across folds as mean ± std.
def train_classifier_cv(
model_fn,
raw_dataset,
splits,
*,
epochs=10,
batch_size=16,
lr=1e-4,
weight_decay=1e-4,
device="cuda",
save_dir="outputs/models",
run_name="classifier_cv",
early_stopping_patience=0,
num_workers=4,
transform_builder=None,
grad_clip_norm=1.0,
T_max=None,
normalization=None,
logs_dir=None,
):
from src.evaluation.evaluate import (
predict_rows, save_errors, save_hists, save_preds, save_summary,
)
    from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics
from src.utils.cross_validation import aggregate_fold_metrics
device = torch.device(device if torch.cuda.is_available() else "cpu")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
eval_dir = Path(logs_dir) / run_name if logs_dir is not None else None
if eval_dir is not None:
eval_dir.mkdir(parents=True, exist_ok=True)
fold_results = []
all_records = []
for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
print(f"\n{'='*60}")
print(f"Fold {fold_idx + 1}/{len(splits)}")
print(f" Train: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}")
print(f"{'='*60}")
model = model_fn().to(device)
norm_mean = norm_std = None
if normalization == "real_norm":
from src.preprocessing.pipeline import compute_real_stats
norm_mean, norm_std = compute_real_stats(raw_dataset, train_idx)
print(f" Real-norm stats: mean={norm_mean}, std={norm_std}")
if transform_builder is not None:
train_dataset = transform_builder(train_idx, train=True,
normalize_mean=norm_mean, normalize_std=norm_std)
val_dataset = transform_builder(val_idx, train=False,
normalize_mean=norm_mean, normalize_std=norm_std)
test_dataset = transform_builder(test_idx, train=False,
normalize_mean=norm_mean, normalize_std=norm_std)
else:
train_dataset = Subset(raw_dataset, train_idx)
val_dataset = Subset(raw_dataset, val_idx)
test_dataset = Subset(raw_dataset, test_idx)
# Compute pos_weight = n_real / n_fake for BCEWithLogitsLoss class balancing
train_labels = [raw_dataset.samples[i][1] for i in train_idx]
class_counts = Counter(train_labels)
pos_weight = class_counts[0] / class_counts[1] if class_counts[1] > 0 else 1.0
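        # e.g. 700 real (label 0) vs. 300 fake (label 1) -> pos_weight ~= 2.33,
        # so each positive (fake) sample is weighted ~2.3x in the loss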
fold_run_name = f"{run_name}_fold{fold_idx}"
history = train_classifier(
model,
train_dataset,
val_dataset,
epochs=epochs,
batch_size=batch_size,
lr=lr,
weight_decay=weight_decay,
device=device,
save_dir=save_dir,
run_name=fold_run_name,
early_stopping_patience=early_stopping_patience,
num_workers=num_workers,
grad_clip_norm=grad_clip_norm,
T_max=T_max,
pos_weight=pos_weight,
)
# Load best checkpoint and evaluate on test set
checkpoint_path = save_dir / f"{fold_run_name}_best.pt"
if checkpoint_path.exists():
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()
records = predict_rows(
model, test_dataset, raw_dataset, test_idx,
batch_size, device, num_workers=num_workers,
)
# Compute aggregate and per-source test metrics
test_metrics = calc_metrics(records)
src_metrics = source_metrics(records)
pairwise = pair_metrics(records)
fold_result = {
"fold": fold_idx,
"train_size": len(train_idx),
"val_size": len(val_idx),
"test_size": len(test_idx),
"history": history,
"test_metrics": test_metrics,
"source_metrics": src_metrics,
"pair_metrics": pairwise,
}
fold_results.append(fold_result)
all_records.extend(records)
if eval_dir is not None:
fold_dir = eval_dir / f"fold{fold_idx}"
save_preds(records, fold_dir / "preds.csv")
save_errors(records, fold_dir / "errors.json")
print(f"\nFold {fold_idx + 1} Test Metrics:")
for key, value in test_metrics.items():
if key != "confusion_matrix":
print(f" {key}: {value}")
print(f" Per-source AUC:")
for source, sm in sorted(src_metrics.items()):
pa = sm.get("pairwise_auc")
dr = sm.get("detection_rate")
label = f"pairwise_auc={pa:.4f}" if pa is not None else f"detection_rate={dr:.4f}" if dr is not None else ""
print(f" {source}: {label}")
# Aggregate metrics across folds
test_metrics_list = [f["test_metrics"] for f in fold_results]
aggregated = aggregate_fold_metrics(test_metrics_list)
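    # aggregate_fold_metrics returns, per metric, a {"mean": ..., "std": ..., "ci_95": ...}
    # dict (these keys are consumed by the summary printout below)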
# Aggregate per-source metrics across folds
all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
aggregated_per_source = {}
for source in all_sources:
source_fold_metrics = []
for f in fold_results:
sm = f["source_metrics"].get(source)
if sm:
# Only keep scalar numeric fields for aggregation
source_fold_metrics.append({
k: v for k, v in sm.items()
if isinstance(v, (int, float)) and k != "fold"
})
if source_fold_metrics:
aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)
# Aggregate pairwise source metrics across folds
all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
aggregated_pairwise = {}
for pair in all_pairs:
pair_fold_metrics = []
for f in fold_results:
pm = f["pair_metrics"].get(pair)
if pm:
pair_fold_metrics.append({
k: v for k, v in pm.items()
if isinstance(v, (int, float)) and k not in ("fold", "n")
})
if pair_fold_metrics:
aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)
results = {
"run_name": run_name,
"n_folds": len(splits),
"fold_results": fold_results,
"aggregated_metrics": aggregated,
"aggregated_per_source": aggregated_per_source,
"aggregated_pairwise": aggregated_pairwise,
}
if eval_dir is not None:
save_hists(all_records, eval_dir / "hists")
save_summary(results, eval_dir / "summary.json")
print(f"\n{'='*60}")
print("Cross-Validation Results (Aggregated)")
print(f"{'='*60}")
for key, value in aggregated.items():
print(f" {key}:")
print(f" mean: {value['mean']:.4f}")
print(f" std: {value['std']:.4f}")
print(f" 95% CI: ±{value['ci_95']:.4f}")
    if aggregated_per_source:
        print("\nPer-Source Pairwise AUC (wiki vs. fake source):")
        for source in sorted(aggregated_per_source):
            ps = aggregated_per_source[source]
            pa = ps.get("pairwise_auc")
            if pa:
                print(f" {source}: {pa['mean']:.4f} ± {pa['std']:.4f}")
            elif ps.get("detection_rate"):
                dr = ps["detection_rate"]
                print(f" {source}: detection_rate={dr['mean']:.4f} ± {dr['std']:.4f}")
return results
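# Usage sketch (hypothetical `build_model` / `make_splits` helpers; the split
# format [(train_idx, val_idx, test_idx), ...] is the one consumed above):
#
#   splits = make_splits(raw_dataset, n_folds=5)
#   results = train_classifier_cv(
#       build_model, raw_dataset, splits,
#       epochs=15, run_name="classifier_cv", logs_dir="outputs/logs",
#   )
#   for metric, stats in results["aggregated_metrics"].items():
#       print(metric, stats["mean"], stats["ci_95"])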