703 lines
31 KiB
Plaintext
703 lines
31 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Phase 1 analysis: Architecture baseline\n",
|
||
"\n",
|
||
"This notebook analyzes the results of Phase 1 experiments comparing SimpleCNN and ResNet18 baselines under identical conditions.\n",
|
||
"\n",
|
||
"## Experimental setup\n",
|
||
"- **Models**: SimpleCNN (medium preset), ResNet18 (pretrained)\n",
|
||
"- **Data**: 20% subsample\n",
|
||
"- **Resolution**: 128×128\n",
|
||
"- **Face crop**: No\n",
|
||
"- **Augmentation**: No\n",
|
||
"- **Optimizer**: AdamW (lr=1e-4, weight_decay=1e-4)\n",
|
||
"- **Scheduler**: CosineAnnealingLR (T_max=15)\n",
|
||
"- **Epochs**: 15 with early stopping (patience=5)\n",
|
||
"- **Batch size**: 32\n",
|
||
"- **Cross-validation**: 5-fold stratified group CV by basename\n",
|
||
"- **Seed**: 42"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from scipy import stats\n",
"\n",
"# Plot styling shared by every figure in this notebook\n",
"sns.set_style(\"whitegrid\")\n",
"plt.rcParams.update({'figure.figsize': (12, 6), 'font.size': 10})\n",
"\n",
"# Output locations. OUTPUTS_DIR is the *logs* folder: load_cv_results reads <run_name>.json from it.\n",
"OUTPUTS_ROOT = Path(\"../outputs\")\n",
"OUTPUTS_DIR = OUTPUTS_ROOT / \"logs\"\n",
"MODELS_DIR = OUTPUTS_ROOT / \"models\"\n",
"FIGURES_DIR = OUTPUTS_ROOT / \"figures\"\n",
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"banner = \"Phase 1 Analysis: Architecture Baseline\"\n",
"print(banner)\n",
"print(\"=\" * 50)"
]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Load CV results"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_cv_results(run_name):\n",
"    \"\"\"Return the parsed CV-results JSON for `run_name`, or None when it is missing.\n",
"\n",
"    The file is expected at OUTPUTS_DIR / f\"{run_name}.json\"; a warning is printed if it is absent.\n",
"    \"\"\"\n",
"    results_path = OUTPUTS_DIR / f\"{run_name}.json\"\n",
"    if results_path.exists():\n",
"        with open(results_path) as fh:\n",
"            return json.load(fh)\n",
"    print(f\"Warning: {results_path} not found\")\n",
"    return None\n",
"\n",
"\n",
"# Load results for both models\n",
"simplecnn_results = load_cv_results(\"p1_simplecnn_baseline\")\n",
"resnet18_results = load_cv_results(\"p1_resnet18_baseline\")\n",
"\n",
"for label, loaded in ((\"SimpleCNN\", simplecnn_results), (\"ResNet18\", resnet18_results)):\n",
"    print(f\"{label} results loaded: {loaded is not None}\")"
]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Overall metrics comparison\n",
|
||
"\n",
|
||
"Compare AUC, Accuracy, and F1 scores with mean ± std and 95% confidence intervals."
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Short metric prefixes used throughout this notebook -> key in results['aggregated_metrics']\n",
"AGG_METRIC_KEYS = {'auc': 'auc_roc', 'acc': 'accuracy', 'f1': 'f1'}\n",
"\n",
"\n",
"def extract_aggregated_metrics(results, model_name):\n",
"    \"\"\"Flatten one model's CV-aggregated AUC/accuracy/F1 statistics into a dict.\n",
"\n",
"    Args:\n",
"        results: parsed CV-results JSON (from load_cv_results), or None.\n",
"        model_name: label stored under the 'model' key.\n",
"\n",
"    Returns:\n",
"        None when `results` is None; otherwise a dict with 'model' plus '<p>_mean',\n",
"        '<p>_std' and '<p>_ci' for every prefix p in AGG_METRIC_KEYS (auc, acc, f1).\n",
"    \"\"\"\n",
"    if results is None:\n",
"        return None\n",
"\n",
"    agg = results['aggregated_metrics']\n",
"    metrics = {'model': model_name}\n",
"    for prefix, agg_key in AGG_METRIC_KEYS.items():\n",
"        metrics[f'{prefix}_mean'] = agg[agg_key]['mean']\n",
"        metrics[f'{prefix}_std'] = agg[agg_key]['std']\n",
"        # NOTE(review): 'ci_95' is printed as a ±half-width below; confirm it is a scalar, not [lo, hi]\n",
"        metrics[f'{prefix}_ci'] = agg[agg_key]['ci_95']\n",
"    return metrics\n",
"\n",
"\n",
"def format_mean_std_ci(mean, std, ci):\n",
"    \"\"\"Render 'mean ± std (95% CI: ±ci)' with four decimals.\"\"\"\n",
"    return f\"{mean:.4f} ± {std:.4f} (95% CI: ±{ci:.4f})\"\n",
"\n",
"\n",
"# Extract metrics\n",
"simplecnn_metrics = extract_aggregated_metrics(simplecnn_results, 'SimpleCNN')\n",
"resnet18_metrics = extract_aggregated_metrics(resnet18_results, 'ResNet18')\n",
"\n",
"if simplecnn_metrics and resnet18_metrics:\n",
"    # One row per model; built without inplace mutation so re-running the cell is safe\n",
"    comparison_df = pd.DataFrame([simplecnn_metrics, resnet18_metrics]).set_index('model')\n",
"    n_folds = len(simplecnn_results['fold_results'])\n",
"\n",
"    print(f\"\\nOverall Metrics Comparison ({n_folds}-fold CV):\")\n",
"    print(\"=\" * 80)\n",
"    for prefix in AGG_METRIC_KEYS:\n",
"        print(f\"\\n{prefix.upper()}:\")\n",
"        for model, row in comparison_df.iterrows():\n",
"            formatted = format_mean_std_ci(row[f'{prefix}_mean'], row[f'{prefix}_std'], row[f'{prefix}_ci'])\n",
"            print(f\"  {model}: {formatted}\")\n",
"\n",
"    # Signed deltas: the previous hard-coded '+' printed '+-0.0123' whenever ResNet18 was worse\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(\"ResNet18 vs SimpleCNN difference (positive = ResNet18 higher):\")\n",
"    print(\"=\" * 80)\n",
"    for prefix in AGG_METRIC_KEYS:\n",
"        baseline = simplecnn_metrics[f'{prefix}_mean']\n",
"        mean_diff = resnet18_metrics[f'{prefix}_mean'] - baseline\n",
"        pct = f\"{mean_diff / baseline * 100:+.2f}%\" if baseline else \"n/a\"\n",
"        print(f\"  {prefix.upper()}: {mean_diff:+.4f} ({pct})\")\n",
"else:\n",
"    print(\"Skipping comparison: CV results for both models are required.\")"
]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Visualization: Overall metrics comparison"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def padded_ylim(values, errors, floor=0.5, ceiling=1.0, margin=0.05):\n",
"    \"\"\"Y-axis limits that keep every bar, error bar and value label visible.\n",
"\n",
"    Starts from the (floor, ceiling) window used before and widens it only when the data needs it.\n",
"    \"\"\"\n",
"    lowest = min(v - e for v, e in zip(values, errors))\n",
"    highest = max(v + e for v, e in zip(values, errors))\n",
"    return min(floor, lowest - margin), max(ceiling, highest + margin)\n",
"\n",
"\n",
"if simplecnn_metrics and resnet18_metrics:\n",
"    models = ['SimpleCNN', 'ResNet18']\n",
"    per_model = [simplecnn_metrics, resnet18_metrics]\n",
"    panels = {'AUC-ROC': 'auc', 'Accuracy': 'acc', 'F1 Score': 'f1'}  # panel title -> metric prefix\n",
"    colors = ['#e74c3c', '#2ecc71']  # Red for SimpleCNN, Green for ResNet18\n",
"\n",
"    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))\n",
"    for ax, (metric_name, prefix) in zip(axes, panels.items()):\n",
"        values = [m[f'{prefix}_mean'] for m in per_model]\n",
"        errors = [m[f'{prefix}_std'] for m in per_model]  # error bars = std across folds\n",
"        bars = ax.bar(models, values, yerr=errors, capsize=5, alpha=0.7, color=colors)\n",
"        ax.set_ylabel(metric_name)\n",
"        ax.set_title(f'{metric_name} Comparison')\n",
"        # A fixed ylim(0.5, 1.0) silently clipped bars or error bars outside that window\n",
"        ax.set_ylim(*padded_ylim(values, errors))\n",
"\n",
"        # Value labels sit on top of the error-bar cap so the two don't overlap\n",
"        for bar, value, err in zip(bars, values, errors):\n",
"            ax.text(bar.get_x() + bar.get_width() / 2., value + err,\n",
"                    f'{value:.4f}',\n",
"                    ha='center', va='bottom', fontweight='bold')\n",
"\n",
"    fig.tight_layout()\n",
"    fig.savefig(FIGURES_DIR / 'phase1_overall_metrics.png', dpi=300, bbox_inches='tight')\n",
"    plt.show()\n",
"else:\n",
"    print(\"Skipping plot: CV results for both models are required.\")"
]
},
|
||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Per-source metrics\n",
"\n",
"Compare performance on each fake source (text2img, inpainting, insight). These metrics are reported only when the fold results contain a `per_source` entry under `test_metrics`. If the current CV results do not include it, the cell below prints a note instead, and the overall metrics above remain the only comparison."
]
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Short metric name -> key inside each fold's test_metrics['per_source'][source]\n",
"PER_SOURCE_KEYS = {'auc': 'auc_roc', 'acc': 'accuracy', 'f1': 'f1'}\n",
"\n",
"\n",
"def extract_per_source_metrics(results, model_name):\n",
"    \"\"\"Aggregate per-fake-source test metrics across CV folds.\n",
"\n",
"    Args:\n",
"        results: parsed CV-results JSON, or None.\n",
"        model_name: label stored under the 'model' key.\n",
"\n",
"    Returns:\n",
"        None when `results` is None; otherwise {'model': model_name, 'sources': {...}}, where\n",
"        'sources' maps each source to '<m>_mean' / '<m>_std' for m in auc, acc, f1. 'sources'\n",
"        is empty when no fold has a 'per_source' entry; a mean is None when no fold reported it.\n",
"    \"\"\"\n",
"    if results is None:\n",
"        return None\n",
"\n",
"    # Collect per-source values across folds, skipping metrics that are absent or None\n",
"    source_values = {}\n",
"    for fold_result in results['fold_results']:\n",
"        per_source = fold_result['test_metrics'].get('per_source') or {}\n",
"        for source, metrics in per_source.items():\n",
"            bucket = source_values.setdefault(source, {name: [] for name in PER_SOURCE_KEYS})\n",
"            for name, key in PER_SOURCE_KEYS.items():\n",
"                if metrics.get(key) is not None:\n",
"                    bucket[name].append(metrics[key])\n",
"\n",
"    # np.std uses ddof=0 as before (NOTE(review): confirm this matches the std in aggregated_metrics)\n",
"    aggregated = {}\n",
"    for source, values in source_values.items():\n",
"        aggregated[source] = {}\n",
"        for name, vals in values.items():\n",
"            aggregated[source][f'{name}_mean'] = np.mean(vals) if vals else None\n",
"            aggregated[source][f'{name}_std'] = np.std(vals) if len(vals) > 1 else 0\n",
"    return {'model': model_name, 'sources': aggregated}\n",
"\n",
"\n",
"def format_mean_std(source_stats, name):\n",
"    \"\"\"'mean±std' with four decimals, or 'N/A' when the metric is missing for this source.\"\"\"\n",
"    mean = source_stats.get(f'{name}_mean')\n",
"    if mean is None:\n",
"        return 'N/A'\n",
"    std = source_stats.get(f'{name}_std', 0)\n",
"    return f\"{mean:.4f}±{std:.4f}\"\n",
"\n",
"\n",
"# Extract per-source metrics\n",
"simplecnn_source = extract_per_source_metrics(simplecnn_results, 'SimpleCNN')\n",
"resnet18_source = extract_per_source_metrics(resnet18_results, 'ResNet18')\n",
"\n",
"# The returned dicts are always truthy once results load, so test the 'sources' payload itself\n",
"have_per_source = bool(simplecnn_source and resnet18_source\n",
"                       and (simplecnn_source['sources'] or resnet18_source['sources']))\n",
"if have_per_source:\n",
"    print(\"\\nPer-Source Metrics Comparison:\")\n",
"    print(\"=\" * 80)\n",
"    for source in sorted(set(simplecnn_source['sources']) | set(resnet18_source['sources'])):\n",
"        print(f\"\\nSource: {source}\")\n",
"        print(\"-\" * 40)\n",
"        for label, per_model in ((\"SimpleCNN\", simplecnn_source), (\"ResNet18\", resnet18_source)):\n",
"            source_stats = per_model['sources'].get(source, {})\n",
"            print(f\"  {label}: AUC={format_mean_std(source_stats, 'auc')}, \"\n",
"                  f\"Acc={format_mean_std(source_stats, 'acc')}, \"\n",
"                  f\"F1={format_mean_std(source_stats, 'f1')}\")\n",
"else:\n",
"    print(\"\\nNote: Per-source metrics not available in current CV results format.\")\n",
"    print(\"The models were evaluated on all sources combined.\")"
]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Train/Val/Test performance curves"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def stack_padded_series(histories, key, n_epochs):\n",
"    \"\"\"Stack one per-epoch series from every fold into a (n_folds, n_epochs) float array.\n",
"\n",
"    Folds that stopped early (or lack `key`) are NaN-padded. A new array is built instead of\n",
"    appending NaNs to the loaded history lists, so the results dicts are left unmodified.\n",
"    \"\"\"\n",
"    stacked = np.full((len(histories), n_epochs), np.nan)\n",
"    for row, history in enumerate(histories):\n",
"        series = np.asarray(history.get(key, []), dtype=float)\n",
"        stacked[row, :len(series)] = series\n",
"    return stacked\n",
"\n",
"\n",
"def plot_training_curves(results, model_name, ax):\n",
"    \"\"\"Plot the across-fold mean ± std of the loss and AUC curves for one model.\n",
"\n",
"    Args:\n",
"        results: parsed CV-results JSON, or None (nothing is drawn).\n",
"        model_name: used in the legend labels and panel titles.\n",
"        ax: two axes; ax[0] gets train/val loss, ax[1] gets train/val AUC.\n",
"    \"\"\"\n",
"    if results is None:\n",
"        return\n",
"\n",
"    histories = [fold['history'] for fold in results['fold_results']]\n",
"    keys = ('train_loss', 'val_loss', 'train_auc', 'val_auc')\n",
"    # Early stopping ends folds at different epochs; size the grid by the longest series of any key\n",
"    n_epochs = max((len(h.get(key, [])) for h in histories for key in keys), default=0)\n",
"    if n_epochs == 0:\n",
"        return\n",
"    epochs = np.arange(1, n_epochs + 1)\n",
"\n",
"    panels = (\n",
"        (ax[0], 'loss', 'Loss', 'Training/Validation Loss'),\n",
"        (ax[1], 'auc', 'AUC-ROC', 'Training/Validation AUC'),\n",
"    )\n",
"    for panel, suffix, ylabel, title in panels:\n",
"        for split, marker in (('train', 'o'), ('val', 's')):\n",
"            per_fold = stack_padded_series(histories, f'{split}_{suffix}', n_epochs)\n",
"            mean = np.nanmean(per_fold, axis=0)\n",
"            std = np.nanstd(per_fold, axis=0)\n",
"            (line,) = panel.plot(epochs, mean, label=f'{model_name} ({split})', marker=marker, linewidth=2)\n",
"            # Tie the band colour to its line instead of relying on the separate patch colour cycle\n",
"            panel.fill_between(epochs, mean - std, mean + std, alpha=0.2, color=line.get_color())\n",
"        panel.set_xlabel('Epoch', fontweight='bold')\n",
"        panel.set_ylabel(ylabel, fontweight='bold')\n",
"        # Model name in the title: both rows of the figure previously had identical titles\n",
"        panel.set_title(f'{model_name}: {title}', fontweight='bold')\n",
"        panel.legend()\n",
"        panel.grid(True, alpha=0.3)\n",
"    ax[1].set_ylim(0.5, 1.0)\n",
"\n",
"\n",
"# Plot curves for both models: one row per model, columns = [loss, AUC]\n",
"fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
"plot_training_curves(simplecnn_results, 'SimpleCNN', axes[0])\n",
"plot_training_curves(resnet18_results, 'ResNet18', axes[1])\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(FIGURES_DIR / 'phase1_training_curves.png', dpi=300, bbox_inches='tight')\n",
"plt.show()"
]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Confusion matrices"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_confusion_matrices(results, model_name, ax):\n",
"    \"\"\"Plot the confusion matrix summed over all CV folds, annotated with counts and row percentages.\n",
"\n",
"    Rows are true labels and columns predicted labels, ticked as ['Real', 'Fake'] (NOTE(review):\n",
"    assumes the stored matrix uses index 0 = Real, 1 = Fake; confirm against the training code).\n",
"    Each row is divided by its total, so a cell shows the share of that true class.\n",
"    \"\"\"\n",
"    if results is None:\n",
"        return\n",
"\n",
"    fold_cms = [np.asarray(fold['test_metrics']['confusion_matrix']) for fold in results['fold_results']]\n",
"    if not fold_cms:\n",
"        return\n",
"    # np.sum keeps the stored dtype; the old int accumulator with += rejected float-valued matrices\n",
"    total_cm = np.sum(fold_cms, axis=0)\n",
"\n",
"    # Row-normalise; a class with no samples gets zeros instead of a divide-by-zero NaN row\n",
"    row_totals = total_cm.sum(axis=1, keepdims=True)\n",
"    cm_normalized = np.divide(total_cm, row_totals, out=np.zeros(total_cm.shape), where=row_totals > 0)\n",
"\n",
"    im = ax.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=1)\n",
"    ax.figure.colorbar(im, ax=ax)\n",
"\n",
"    # Colours are pinned to [0, 1] (vmin/vmax), so flip text colour at the scale midpoint, not at max/2\n",
"    for i in range(2):\n",
"        for j in range(2):\n",
"            ax.text(j, i, f'{total_cm[i, j]}\\n({cm_normalized[i, j]:.2%})',\n",
"                    ha=\"center\", va=\"center\",\n",
"                    color=\"white\" if cm_normalized[i, j] > 0.5 else \"black\", fontsize=12)\n",
"\n",
"    ax.set_ylabel('True Label', fontweight='bold')\n",
"    ax.set_xlabel('Predicted Label', fontweight='bold')\n",
"    ax.set_title(f'{model_name} Confusion Matrix', fontweight='bold')\n",
"    ax.set_xticks([0, 1])\n",
"    ax.set_yticks([0, 1])\n",
"    ax.set_xticklabels(['Real', 'Fake'])\n",
"    ax.set_yticklabels(['Real', 'Fake'])\n",
"\n",
"\n",
"# Plot confusion matrices\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
"plot_confusion_matrices(simplecnn_results, 'SimpleCNN', axes[0])\n",
"plot_confusion_matrices(resnet18_results, 'ResNet18', axes[1])\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(FIGURES_DIR / 'phase1_confusion_matrices.png', dpi=300, bbox_inches='tight')\n",
"plt.show()"
]
},
|
||
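{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Statistical significance testing\n",
"\n",
"Run paired t-tests on the per-fold test metrics to check whether the differences between the two models are statistically significant. The pairing assumes fold *i* of both runs used the same test split (same seed and group assignment). With only 5 folds the tests have little statistical power, so read the effect sizes alongside the p-values."
]
},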
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Result key -> (key in each fold's test_metrics, label for printing)\n",
"TEST_METRICS = {'auc': ('auc_roc', 'AUC'), 'accuracy': ('accuracy', 'Accuracy'), 'f1': ('f1', 'F1')}\n",
"\n",
"\n",
"def cohens_d(x1, x2):\n",
"    \"\"\"Cohen's d using the pooled sample SD (ddof=1); NaN when that SD is 0.\n",
"\n",
"    Note: this treats the folds as two independent samples; it does not use the pairing.\n",
"    \"\"\"\n",
"    n1, n2 = len(x1), len(x2)\n",
"    var1, var2 = np.var(x1, ddof=1), np.var(x2, ddof=1)\n",
"    pooled_std = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2))\n",
"    if pooled_std == 0:\n",
"        return float('nan')\n",
"    return (np.mean(x1) - np.mean(x2)) / pooled_std\n",
"\n",
"\n",
"def effect_label(d):\n",
"    \"\"\"Conventional magnitude label for Cohen's d (0.2 / 0.5 / 0.8 cut-offs).\"\"\"\n",
"    magnitude = abs(d)\n",
"    if np.isnan(magnitude):\n",
"        return 'undefined'\n",
"    if magnitude > 0.8:\n",
"        return 'large'\n",
"    if magnitude > 0.5:\n",
"        return 'medium'\n",
"    if magnitude >= 0.2:\n",
"        return 'small'\n",
"    return 'negligible'\n",
"\n",
"\n",
"def perform_statistical_tests(results1, results2, model1_name, model2_name, alpha=0.05):\n",
"    \"\"\"Paired t-tests (fold by fold) on test AUC, accuracy and F1 between two CV runs.\n",
"\n",
"    Pairing assumes fold i of both runs used the same test split (same CV seed/grouping); TODO confirm.\n",
"    The t-statistic is computed on (model1 - model2), so a negative value favours model2.\n",
"\n",
"    Args:\n",
"        results1, results2: parsed CV-results JSON for each model, or None.\n",
"        model1_name, model2_name: labels for the printed report.\n",
"        alpha: significance level for the 'Significant' column (default 0.05, as before).\n",
"\n",
"    Returns:\n",
"        None if either result set is missing; otherwise a dict with keys 'auc', 'accuracy', 'f1'\n",
"        mapping to the scipy ttest_rel result (.statistic, .pvalue).\n",
"\n",
"    Raises:\n",
"        ValueError: if the two runs have a different number of folds (pairing is impossible).\n",
"    \"\"\"\n",
"    if results1 is None or results2 is None:\n",
"        return None\n",
"\n",
"    folds1, folds2 = results1['fold_results'], results2['fold_results']\n",
"    if len(folds1) != len(folds2):\n",
"        raise ValueError(f\"Paired tests need matched folds: {model1_name} has {len(folds1)}, \"\n",
"                         f\"{model2_name} has {len(folds2)}\")\n",
"    n_folds = len(folds1)\n",
"\n",
"    # Per-fold test values, kept in fold order so ttest_rel pairs fold i with fold i\n",
"    per_fold = {\n",
"        name: ([fold['test_metrics'][key] for fold in folds1], [fold['test_metrics'][key] for fold in folds2])\n",
"        for name, (key, _) in TEST_METRICS.items()\n",
"    }\n",
"    results = {name: stats.ttest_rel(a, b) for name, (a, b) in per_fold.items()}\n",
"\n",
"    sig_header = f\"Significant (α={alpha:g})\"\n",
"    print(f\"\\nStatistical Significance Testing: {model1_name} vs {model2_name}\")\n",
"    print(\"=\" * 80)\n",
"    print(f\"\\nPaired t-test ({n_folds} folds, statistic = {model1_name} - {model2_name}):\")\n",
"    print(f\"{'Metric':<15} {'t-statistic':<15} {'p-value':<15} {sig_header:<25}\")\n",
"    print(\"-\" * 80)\n",
"    for name, test_result in results.items():\n",
"        sig_str = \"*** YES ***\" if test_result.pvalue < alpha else \"No\"\n",
"        label = TEST_METRICS[name][1]\n",
"        print(f\"{label:<15} {test_result.statistic:<15.4f} {test_result.pvalue:<15.6f} {sig_str:<25}\")\n",
"\n",
"    print(\"\\n\" + \"-\" * 80)\n",
"    print(f\"Effect Sizes (Cohen's d, pooled SD; positive favours {model1_name}):\")\n",
"    print(\"-\" * 80)\n",
"    for name, (a, b) in per_fold.items():\n",
"        d = cohens_d(a, b)\n",
"        print(f\"  {TEST_METRICS[name][1]}: {d:.4f} ({effect_label(d)} effect)\")\n",
"\n",
"    return results\n",
"\n",
"\n",
"# test_results is always defined (None if a run is missing) so the summary cell can check it safely\n",
"test_results = None\n",
"if simplecnn_results and resnet18_results:\n",
"    test_results = perform_statistical_tests(\n",
"        simplecnn_results, resnet18_results, 'SimpleCNN', 'ResNet18'\n",
"    )"
]
},
|
||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Grad-CAM visualizations\n",
"\n",
"Generate Grad-CAM visualizations to understand which image regions the models focus on.\n",
"\n",
"**Note**: This section needs the trained models and sample images. The Grad-CAM code is provided, but running it requires:\n",
"1. Loading the trained model checkpoints\n",
"2. Selecting sample images from the test set\n",
"3. Running the Grad-CAM algorithm\n",
"\n",
"For now, the code cell below is a template: uncomment its steps once the model checkpoints are available."
]
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Make the project root importable; the guard stops re-runs from stacking duplicate entries\n",
"if '..' not in sys.path:\n",
"    sys.path.insert(0, '..')\n",
"\n",
"try:\n",
"    from src.data import DFFDataset, get_splits, build_transforms\n",
"    from src.models import get_model\n",
"    from src.utils import load_config, resolve_nested_fields\n",
"except ImportError as exc:\n",
"    # Keep Restart & Run All working when the project package is not importable from here\n",
"    print(f\"Project modules unavailable ({exc}); the Grad-CAM code below stays a template.\")\n",
"\n",
"# MODELS_DIR and FIGURES_DIR come from the setup cell. Do not re-assign OUTPUTS_DIR here: it is the\n",
"# logs folder that load_cv_results reads, and later or re-run cells would silently use the new value.\n",
"\n",
"# Load config and rebuild test split for fold 0\n",
"# cfg = load_config(\"../configs/phase1/p1_resnet18_baseline.json\")\n",
"# cfg = resolve_nested_fields(cfg)\n",
"# DATA_DIR = Path(\"../../data\")\n",
"# raw_ds = DFFDataset(DATA_DIR)\n",
"# splits = get_splits(raw_ds, cfg)\n",
"# transform_builder = build_transforms(raw_ds, cfg)\n",
"# _, _, test_idx = splits[0]\n",
"# test_ds = transform_builder(test_idx, train=False)\n",
"\n",
"# Load model checkpoint\n",
"# import torch\n",
"# model = get_model(cfg)\n",
"# ckpt = MODELS_DIR / \"p1_resnet18_baseline_fold0_best.pt\"\n",
"# model.load_state_dict(torch.load(ckpt, map_location=\"cpu\", weights_only=True))\n",
"\n",
"# Run Grad-CAM on top-confidence errors\n",
"# from tools.gradcam import save_overlays\n",
"# records = [...] # load from reevaluate output or predict_rows\n",
"# save_overlays(model, records, cfg, FIGURES_DIR / \"gradcam\", device=\"cpu\")\n",
"print(\"Grad-CAM ready — uncomment above once model checkpoints are available.\")\n"
]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Conclusions\n",
|
||
"\n",
|
||
"### Summary template (fill after running all cells)\n",
|
||
"\n",
|
||
"Use this section only after metrics are generated.\n",
|
||
"Replace placeholders (`<...>`) with measured values.\n",
|
||
"\n",
|
||
"#### 1. Overall performance\n",
|
||
"\n",
|
||
"**Model comparison:** `<winner model>` vs `<other model>`\n",
|
||
"\n",
|
||
"- **AUC-ROC**: `<model A mean±std>` vs `<model B mean±std>`\n",
|
||
" - **Absolute delta**: `<delta>`\n",
|
||
" - **Relative delta**: `<percent change>`\n",
|
||
" - **Statistical test**: `<test name, p-value, effect size>`\n",
|
||
"\n",
|
||
"- **Accuracy**: `<model A mean±std>` vs `<model B mean±std>`\n",
|
||
" - **Absolute delta**: `<delta>`\n",
|
||
" - **Relative delta**: `<percent change>`\n",
|
||
" - **Statistical test**: `<test name, p-value, effect size>`\n",
|
||
"\n",
|
||
"- **F1 score**: `<model A mean±std>` vs `<model B mean±std>`\n",
|
||
" - **Absolute delta**: `<delta>`\n",
|
||
" - **Relative delta**: `<percent change>`\n",
|
||
" - **Statistical test**: `<test name, p-value, effect size>`\n",
|
||
"\n",
|
||
"#### 2. Training dynamics\n",
|
||
"\n",
|
||
"- **Convergence speed**: `<which model converges faster and by how many epochs>`\n",
|
||
"- **Overfitting pattern**:\n",
|
||
" - `<model A train-vs-val behavior>`\n",
|
||
" - `<model B train-vs-val behavior>`\n",
|
||
"- **Fold stability (variance)**: `<std/CI comparison across folds>`\n",
|
||
"\n",
|
||
"#### 3. Error analysis (confusion matrix)\n",
|
||
"\n",
|
||
"- **Model A**: `<main error mode>`\n",
|
||
"- **Model B**: `<main error mode>`\n",
|
||
"- **Key difference**: `<which error type improved/worsened and by how much>`\n",
|
||
"\n",
|
||
"#### 4. Why the better model likely performs better\n",
|
||
"\n",
|
||
"1. `<reason 1 tied to architecture/pretraining>`\n",
|
||
"2. `<reason 2 tied to optimization/generalization>`\n",
|
||
"3. `<reason 3 tied to feature capacity>`\n",
|
||
"\n",
|
||
"#### 5. Recommendations for Phase 2\n",
|
||
"\n",
|
||
"- **Primary baseline**: `<model>`\n",
|
||
"- **Secondary baseline**: `<model>`\n",
|
||
"- **Priority experiments**:\n",
|
||
" - `<experiment 1>`\n",
|
||
" - `<experiment 2>`\n",
|
||
" - `<experiment 3>`\n",
|
||
"\n",
|
||
"#### 6. Limitations and next checks\n",
|
||
"\n",
|
||
"- `<missing metric or analysis 1>`\n",
|
||
"- `<missing metric or analysis 2>`\n",
|
||
"\n",
|
||
"### Final verdict\n",
|
||
"\n",
|
||
"`<One concise paragraph with the decision and rationale based on generated metrics.>`"
|
||
]
|
||
},
|
||
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save analysis results"
]
},
|
||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Significance level for the verdict (same default as perform_statistical_tests)\n",
"ALPHA = 0.05\n",
"# ../outputs, derived from FIGURES_DIR so it does not depend on whether OUTPUTS_DIR was re-assigned\n",
"SUMMARY_PATH = FIGURES_DIR.parent / 'phase1_analysis_summary.json'\n",
"\n",
"\n",
"def json_float(value):\n",
"    \"\"\"Plain float for json.dump, or None for missing / NaN / infinite values (strict JSON has no NaN).\"\"\"\n",
"    if value is None:\n",
"        return None\n",
"    value = float(value)\n",
"    return value if np.isfinite(value) else None\n",
"\n",
"\n",
"def metric_difference(baseline, candidate, prefix):\n",
"    \"\"\"Absolute and percent change of candidate over baseline for '<prefix>_mean' (percent None if baseline is 0).\"\"\"\n",
"    base = baseline[f'{prefix}_mean']\n",
"    delta = candidate[f'{prefix}_mean'] - base\n",
"    return {\n",
"        'absolute': json_float(delta),\n",
"        'percent': json_float(delta / base * 100) if base else None,\n",
"    }\n",
"\n",
"\n",
"if simplecnn_metrics is None or resnet18_metrics is None:\n",
"    print(\"Skipping summary: CV results for both models are required.\")\n",
"else:\n",
"    # Summary key -> metric prefix used in the *_metrics dicts\n",
"    prefixes = {'auc': 'auc', 'accuracy': 'acc', 'f1': 'f1'}\n",
"    improvement = {name: metric_difference(simplecnn_metrics, resnet18_metrics, prefix)\n",
"                   for name, prefix in prefixes.items()}\n",
"\n",
"    statistical_tests = None\n",
"    if test_results:\n",
"        statistical_tests = {}\n",
"        for name, prefix in prefixes.items():\n",
"            statistical_tests[f'{prefix}_t_stat'] = json_float(test_results[name].statistic)\n",
"            statistical_tests[f'{prefix}_p_value'] = json_float(test_results[name].pvalue)\n",
"\n",
"    # Derive the verdict from the numbers; it used to be hard-coded (\"ResNet18\", \"significant\", \"p < 0.001\")\n",
"    r18_auc, scnn_auc = resnet18_metrics['auc_mean'], simplecnn_metrics['auc_mean']\n",
"    best_model, other_model = ('ResNet18', 'SimpleCNN') if r18_auc >= scnn_auc else ('SimpleCNN', 'ResNet18')\n",
"    auc_p = statistical_tests['auc_p_value'] if statistical_tests else None\n",
"    if auc_p is None:\n",
"        significance = \"AUC significance not tested\"\n",
"    elif auc_p < ALPHA:\n",
"        significance = f\"AUC difference statistically significant (paired t-test p={auc_p:.3g} < {ALPHA})\"\n",
"    else:\n",
"        significance = f\"AUC difference not statistically significant (paired t-test p={auc_p:.3g} >= {ALPHA})\"\n",
"\n",
"    analysis_summary = {\n",
"        'phase': 'phase1',\n",
"        'models': ['SimpleCNN', 'ResNet18'],\n",
"        'simplecnn_metrics': simplecnn_metrics,\n",
"        'resnet18_metrics': resnet18_metrics,\n",
"        'improvement': improvement,\n",
"        'statistical_tests': statistical_tests,\n",
"        'conclusions': {\n",
"            'best_model': best_model,\n",
"            'reason': (f\"Higher mean AUC than {other_model} \"\n",
"                       f\"({max(r18_auc, scnn_auc):.4f} vs {min(r18_auc, scnn_auc):.4f}); {significance}\"),\n",
"            'recommendation': f'Use {best_model} as primary baseline for Phase 2 experiments',\n",
"        },\n",
"    }\n",
"\n",
"    with open(SUMMARY_PATH, 'w') as f:\n",
"        json.dump(analysis_summary, f, indent=2)\n",
"\n",
"    auc_delta = improvement['auc']\n",
"    abs_str = f\"{auc_delta['absolute']:+.4f}\" if auc_delta['absolute'] is not None else \"n/a\"\n",
"    pct_str = f\"{auc_delta['percent']:+.2f}%\" if auc_delta['percent'] is not None else \"n/a\"\n",
"\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(\"Phase 1 Analysis Complete!\")\n",
"    print(\"=\" * 80)\n",
"    print(\"\\nResults saved to:\")\n",
"    for saved in (FIGURES_DIR / 'phase1_overall_metrics.png',\n",
"                  FIGURES_DIR / 'phase1_training_curves.png',\n",
"                  FIGURES_DIR / 'phase1_confusion_matrices.png',\n",
"                  SUMMARY_PATH):\n",
"        print(f\"  - {saved}\")\n",
"    print(\"\\nKey Findings:\")\n",
"    print(f\"  - ResNet18 AUC: {r18_auc:.4f}±{resnet18_metrics['auc_std']:.4f}\")\n",
"    print(f\"  - SimpleCNN AUC: {scnn_auc:.4f}±{simplecnn_metrics['auc_std']:.4f}\")\n",
"    print(f\"  - ResNet18 vs SimpleCNN AUC: {abs_str} ({pct_str})\")\n",
"    print(f\"  - Best model by mean AUC: {best_model}; {significance}\")"
]
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "drl",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.13"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|