Clean state
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
# Public API of the training package: re-export the single-fold and
# cross-validation trainer entry points from the trainer module.
from src.training.trainer import train_classifier, train_classifier_cv

__all__ = ["train_classifier", "train_classifier_cv"]
|
||||
@@ -0,0 +1,374 @@
|
||||
import json
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.evaluation.metrics import binary_metrics
|
||||
|
||||
# ── AMP compatibility shim ─────────────────────────────────────────────────
|
||||
|
||||
# torch.amp.GradScaler / autocast moved from torch.cuda.amp in PyTorch 2.3+
|
||||
if hasattr(torch.amp, "GradScaler"):
|
||||
_GradScaler = torch.amp.GradScaler
|
||||
_autocast = torch.amp.autocast
|
||||
else:
|
||||
from torch.cuda.amp import GradScaler as _OldGradScaler, autocast as _OldAutocast
|
||||
_GradScaler = lambda device="", enabled=True, **kw: _OldGradScaler(enabled=enabled, **kw)
|
||||
_autocast = lambda device_type="", enabled=True, **kw: _OldAutocast(enabled=enabled, **kw)
|
||||
|
||||
|
||||
# ── Single-fold training ───────────────────────────────────────────────────


def train_classifier(
    model,
    train_dataset,
    val_dataset,
    *,
    epochs=10,
    batch_size=16,
    lr=1e-4,
    weight_decay=1e-4,
    device="cuda",
    save_dir="outputs/models",
    run_name="classifier",
    early_stopping_patience=0,
    num_workers=4,
    grad_clip_norm=1.0,
    T_max=None,
    pos_weight=None,
):
    """Train a binary classifier on one train/val fold.

    Saves the best checkpoint (highest validation AUC-ROC) as
    ``{run_name}_best.pt`` and the last-epoch weights as
    ``{run_name}_final.pt`` under ``save_dir``.

    Args:
        model: ``nn.Module`` emitting one logit per sample (shape (B, 1)).
        train_dataset: dataset yielding ``(image, label)`` pairs for training.
        val_dataset: dataset yielding ``(image, label)`` pairs for validation.
        epochs: number of training epochs.
        batch_size: mini-batch size for both loaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        device: "cuda" or "cpu"; mixed precision is enabled only on CUDA.
        save_dir: checkpoint directory (created if missing).
        run_name: checkpoint filename prefix.
        early_stopping_patience: stop after this many epochs without a new
            best validation AUC; 0 disables early stopping.
        num_workers: DataLoader worker processes.
        grad_clip_norm: max gradient L2 norm, applied after unscaling.
        T_max: cosine-annealing period; defaults to ``epochs`` when falsy.
        pos_weight: positive-class weight for BCEWithLogitsLoss (typically
            n_negative / n_positive) to counter class imbalance; None means
            unweighted.

    Returns:
        dict mapping ``train_loss/acc/auc/f1`` and ``val_loss/acc/auc/f1``
        to per-epoch lists.
    """
    device = torch.device(device)
    if device.type == "cuda":
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("Using CPU")

    use_amp = device.type == "cuda"

    model = model.to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_params:,}")

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    pw = torch.tensor([pos_weight], device=device) if pos_weight is not None else None
    criterion = nn.BCEWithLogitsLoss(pos_weight=pw)
    # Only optimize trainable parameters (backbone may be frozen).
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=lr, weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max or epochs)
    scaler = _GradScaler("cuda", enabled=use_amp)
    print(f"Device: {device} AMP: {use_amp}")

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    history = {
        "train_loss": [], "train_acc": [], "train_auc": [], "train_f1": [],
        "val_loss": [], "val_acc": [], "val_auc": [], "val_f1": [],
    }
    best_auc = 0.0
    patience_counter = 0

    for epoch in range(1, epochs + 1):
        # ── train ──
        model.train()
        total_loss = 0.0
        train_logits, train_labels = [], []
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [train]", leave=False):
            images = images.to(device)
            labels = labels.float().to(device)

            optimizer.zero_grad()
            with _autocast("cuda", enabled=use_amp):
                logits = model(images).squeeze(1)
                loss = criterion(logits, labels)

            # Unscale before clipping so the clip threshold applies to true
            # gradient magnitudes, not AMP-scaled ones.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item() * len(images)
            train_logits.append(logits.detach().cpu())
            train_labels.append(labels.detach().cpu())

        # BUGFIX: guard against an empty train loader (ZeroDivisionError).
        n_train = sum(len(chunk) for chunk in train_logits)
        train_loss = total_loss / max(n_train, 1)
        train_m = binary_metrics(torch.cat(train_logits), torch.cat(train_labels))
        scheduler.step()

        # ── validate ──
        model.eval()
        val_loss = 0.0
        val_count = 0  # samples whose (finite) loss actually entered the sum
        all_logits, all_labels = [], []
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [val]", leave=False):
                images = images.to(device)
                labels = labels.float().to(device)
                with _autocast("cuda", enabled=use_amp):
                    logits = model(images).squeeze(1)
                    batch_loss = criterion(logits, labels)
                # Skip non-finite batch losses (AMP overflow) but still keep
                # the logits so the metrics cover the whole val set.
                if not (torch.isnan(batch_loss) or torch.isinf(batch_loss)):
                    val_loss += batch_loss.item() * len(images)
                    val_count += len(images)
                all_logits.append(logits.cpu())
                all_labels.append(labels.cpu())

        # BUGFIX: average over the samples actually summed above, not
        # len(val_dataset) — skipped NaN/Inf batches previously deflated the
        # reported validation loss.  max(.., 1) also guards an empty loader.
        val_loss /= max(val_count, 1)
        val_m = binary_metrics(torch.cat(all_logits), torch.cat(all_labels))

        # ── record ──
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_m["accuracy"])
        history["train_auc"].append(train_m["auc_roc"])
        history["train_f1"].append(train_m["f1"])
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_m["accuracy"])
        history["val_auc"].append(val_m["auc_roc"])
        history["val_f1"].append(val_m["f1"])

        # Train/val gaps are printed as a quick overfitting indicator.
        gap_loss = train_loss - val_loss
        gap_acc = train_m["accuracy"] - val_m["accuracy"]
        print(
            f"[{epoch:03d}/{epochs}] "
            f"loss: {train_loss:.4f}/{val_loss:.4f} (gap {gap_loss:+.4f}) "
            f"acc: {train_m['accuracy']:.4f}/{val_m['accuracy']:.4f} (gap {gap_acc:+.4f}) "
            f"auc: {train_m['auc_roc']:.4f}/{val_m['auc_roc']:.4f} "
            f"f1: {train_m['f1']:.4f}/{val_m['f1']:.4f}"
        )

        # ── checkpoint ──
        if val_m["auc_roc"] is not None and val_m["auc_roc"] > best_auc:
            best_auc = val_m["auc_roc"]
            torch.save(model.state_dict(), save_dir / f"{run_name}_best.pt")
            patience_counter = 0
        else:
            patience_counter += 1

        if early_stopping_patience > 0 and patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {early_stopping_patience} epochs)")
            break

    torch.save(model.state_dict(), save_dir / f"{run_name}_final.pt")
    return history
|
||||
|
||||
|
||||
# ── CV training ────────────────────────────────────────────────────────────

# Iterates over pre-built splits, trains one model per fold, evaluates on the
# held-out test fold, then aggregates metrics across folds with mean ± std
def train_classifier_cv(
    model_fn,
    raw_dataset,
    splits,
    *,
    epochs=10,
    batch_size=16,
    lr=1e-4,
    weight_decay=1e-4,
    device="cuda",
    save_dir="outputs/models",
    run_name="classifier_cv",
    early_stopping_patience=0,
    num_workers=4,
    transform_builder=None,
    grad_clip_norm=1.0,
    T_max=None,
    normalization=None,
    logs_dir=None,
):
    """Run k-fold cross-validated training and evaluation.

    For each ``(train_idx, val_idx, test_idx)`` triple in ``splits``: build a
    fresh model via ``model_fn()``, train it with :func:`train_classifier`
    (best checkpoint selected by validation AUC), reload that checkpoint, and
    evaluate on the held-out test indices.  Fold metrics are then aggregated
    across folds (overall, per-source, and per source-pair).

    Args:
        model_fn: zero-argument callable returning a fresh ``nn.Module``.
        raw_dataset: full dataset; must expose ``.samples[i] -> (path, label)``
            (label 1 is treated as the positive/"fake" class for pos_weight).
        splits: sequence of ``(train_idx, val_idx, test_idx)`` index lists.
        transform_builder: optional callable building a transformed dataset
            view from indices; falls back to plain ``Subset`` when None.
        normalization: "real_norm" computes normalization stats from the
            fold's real training images; anything else leaves stats unset.
        logs_dir: when given, per-fold predictions/errors and the aggregate
            summary are written under ``logs_dir / run_name``.
        (remaining keyword args are passed through to ``train_classifier``.)

    Returns:
        dict with per-fold results plus ``aggregated_metrics``,
        ``aggregated_per_source`` and ``aggregated_pairwise``.
    """
    # Local imports — presumably to avoid a circular import between the
    # training and evaluation packages (TODO confirm).
    from src.evaluation.evaluate import (
        predict_rows, save_errors, save_hists, save_preds, save_summary,
    )
    from src.evaluation.metrics import binary_metrics, calc_metrics, source_metrics, pair_metrics
    from src.utils.cross_validation import aggregate_fold_metrics

    # Fall back to CPU when CUDA is unavailable, whatever was requested.
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Evaluation artifacts (predictions, errors, summary) only when logging.
    eval_dir = Path(logs_dir) / run_name if logs_dir is not None else None
    if eval_dir is not None:
        eval_dir.mkdir(parents=True, exist_ok=True)

    fold_results = []
    all_records = []   # per-sample prediction records pooled across folds

    for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
        print(f"\n{'='*60}")
        print(f"Fold {fold_idx + 1}/{len(splits)}")
        print(f" Train: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}")
        print(f"{'='*60}")

        # Fresh model per fold — no weight leakage between folds.
        model = model_fn().to(device)

        # Optional normalization stats computed from this fold's train split
        # only, so test data never influences preprocessing.
        norm_mean = norm_std = None
        if normalization == "real_norm":
            from src.preprocessing.pipeline import compute_real_stats
            norm_mean, norm_std = compute_real_stats(raw_dataset, train_idx)
            print(f" Real-norm stats: mean={norm_mean}, std={norm_std}")

        # Build the three dataset views; train gets augmentation
        # (train=True), val/test do not.
        if transform_builder is not None:
            train_dataset = transform_builder(train_idx, train=True,
                                              normalize_mean=norm_mean, normalize_std=norm_std)
            val_dataset = transform_builder(val_idx, train=False,
                                            normalize_mean=norm_mean, normalize_std=norm_std)
            test_dataset = transform_builder(test_idx, train=False,
                                             normalize_mean=norm_mean, normalize_std=norm_std)
        else:
            train_dataset = Subset(raw_dataset, train_idx)
            val_dataset = Subset(raw_dataset, val_idx)
            test_dataset = Subset(raw_dataset, test_idx)

        # Compute pos_weight = n_real / n_fake for BCEWithLogitsLoss class balancing
        train_labels = [raw_dataset.samples[i][1] for i in train_idx]
        class_counts = Counter(train_labels)
        pos_weight = class_counts[0] / class_counts[1] if class_counts[1] > 0 else 1.0

        fold_run_name = f"{run_name}_fold{fold_idx}"
        history = train_classifier(
            model,
            train_dataset,
            val_dataset,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            weight_decay=weight_decay,
            device=device,
            save_dir=save_dir,
            run_name=fold_run_name,
            early_stopping_patience=early_stopping_patience,
            num_workers=num_workers,
            grad_clip_norm=grad_clip_norm,
            T_max=T_max,
            pos_weight=pos_weight,
        )

        # Load best checkpoint and evaluate on test set
        # (falls back to the final-epoch weights already in `model` if no
        # best checkpoint was ever written, e.g. AUC never improved).
        checkpoint_path = save_dir / f"{fold_run_name}_best.pt"
        if checkpoint_path.exists():
            model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

        model.eval()

        records = predict_rows(
            model, test_dataset, raw_dataset, test_idx,
            batch_size, device, num_workers=num_workers,
        )

        # Compute aggregate and per-source test metrics
        test_metrics = calc_metrics(records)
        src_metrics = source_metrics(records)
        pairwise = pair_metrics(records)

        fold_result = {
            "fold": fold_idx,
            "train_size": len(train_idx),
            "val_size": len(val_idx),
            "test_size": len(test_idx),
            "history": history,
            "test_metrics": test_metrics,
            "source_metrics": src_metrics,
            "pair_metrics": pairwise,
        }
        fold_results.append(fold_result)
        all_records.extend(records)

        if eval_dir is not None:
            fold_dir = eval_dir / f"fold{fold_idx}"
            save_preds(records, fold_dir / "preds.csv")
            save_errors(records, fold_dir / "errors.json")

        print(f"\nFold {fold_idx + 1} Test Metrics:")
        for key, value in test_metrics.items():
            if key != "confusion_matrix":
                print(f" {key}: {value}")
        print(f" Per-source AUC:")
        for source, sm in sorted(src_metrics.items()):
            # A source reports either pairwise_auc or detection_rate
            # (presumably depending on whether real counterparts exist —
            # TODO confirm against source_metrics).
            pa = sm.get("pairwise_auc")
            dr = sm.get("detection_rate")
            label = f"pairwise_auc={pa:.4f}" if pa is not None else f"detection_rate={dr:.4f}" if dr is not None else ""
            print(f" {source}: {label}")

    # Aggregate metrics across folds
    test_metrics_list = [f["test_metrics"] for f in fold_results]
    aggregated = aggregate_fold_metrics(test_metrics_list)

    # Aggregate per-source metrics across folds
    all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
    aggregated_per_source = {}
    for source in all_sources:
        source_fold_metrics = []
        for f in fold_results:
            sm = f["source_metrics"].get(source)
            if sm:
                # Only keep scalar numeric fields for aggregation
                source_fold_metrics.append({
                    k: v for k, v in sm.items()
                    if isinstance(v, (int, float)) and k != "fold"
                })
        if source_fold_metrics:
            aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)

    # Aggregate pairwise source metrics across folds
    all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
    aggregated_pairwise = {}
    for pair in all_pairs:
        pair_fold_metrics = []
        for f in fold_results:
            pm = f["pair_metrics"].get(pair)
            if pm:
                # Exclude non-metric bookkeeping fields ("fold", sample
                # count "n") from the mean/std aggregation.
                pair_fold_metrics.append({
                    k: v for k, v in pm.items()
                    if isinstance(v, (int, float)) and k not in ("fold", "n")
                })
        if pair_fold_metrics:
            aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)

    results = {
        "run_name": run_name,
        "n_folds": len(splits),
        "fold_results": fold_results,
        "aggregated_metrics": aggregated,
        "aggregated_per_source": aggregated_per_source,
        "aggregated_pairwise": aggregated_pairwise,
    }

    if eval_dir is not None:
        save_hists(all_records, eval_dir / "hists")
        save_summary(results, eval_dir / "summary.json")

    print(f"\n{'='*60}")
    print("Cross-Validation Results (Aggregated)")
    print(f"{'='*60}")
    for key, value in aggregated.items():
        print(f" {key}:")
        print(f" mean: {value['mean']:.4f}")
        print(f" std: {value['std']:.4f}")
        print(f" 95% CI: ±{value['ci_95']:.4f}")

    if aggregated_per_source:
        print(f"\nPer-Source Pairwise AUC (wiki vs. fake source):")
        for source in sorted(aggregated_per_source):
            ps = aggregated_per_source[source]
            pa = ps.get("pairwise_auc", {})
            if pa:
                print(f" {source}: {pa['mean']:.4f} ± {pa['std']:.4f}")
            dr = ps.get("detection_rate")
            if dr and not pa:
                print(f" {source}: detection_rate={dr['mean']:.4f} ± {dr['std']:.4f}")

    return results
|
||||
Reference in New Issue
Block a user