CIFAR CHALLENGE SERIES - PART 5
Ensemble Methods: Combining Models for Maximum Accuracy
EDUSHARK TRAINING
35 min read
Our Progress
Previous: 97.52% | Target: 99.5%
Gap remaining: 1.98%
Ensemble methods combine predictions from multiple models to achieve better performance than any single model. The key insight: different models make different mistakes.
Conditions for Ensemble Success
- Accuracy: Each model should be reasonably accurate (better than random)
- Diversity: Models should make different errors (not correlated)
- Independence: Ideally, trained with different data/architectures/hyperparameters
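The diversity condition has a clean probabilistic intuition: if model errors were truly independent, majority voting would sharply amplify accuracy. A quick sanity check of that best-case claim (standalone math, not our CIFAR models):

```python
from math import comb

def majority_vote_accuracy(p: float, n: int) -> float:
    """Probability that a majority of n independent voters,
    each correct with probability p, votes correctly (odd n)."""
    k_min = n // 2 + 1  # votes needed for a strict majority
    return sum(comb(n, k) * p**k * (1 - p)**(n - k)
               for k in range(k_min, n + 1))

print(f"{majority_vote_accuracy(0.90, 1):.3f}")  # 0.900 (single model)
print(f"{majority_vote_accuracy(0.90, 3):.3f}")  # 0.972
print(f"{majority_vote_accuracy(0.90, 5):.3f}")  # 0.991
```

Real models share training data and inductive biases, so their errors correlate and the gains we actually see are fractions of a percent, not these idealized jumps.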
Single Model:                      Ensemble of 5 Models:

Input Image                        Input Image
     │                                  │
     ▼                         ┌────────┼────────┐
┌─────────┐                    │        │        │
│  Model  │                    ▼        ▼        ▼
│  97.5%  │                ┌─────┐  ┌─────┐  ┌─────┐
└────┬────┘                │ M1  │  │ M2  │  │ M3  │  ...
     │                     │97.5%│  │97.3%│  │97.1%│
     ▼                     └──┬──┘  └──┬──┘  └──┬──┘
Prediction                    │        │        │
                              └────────┼────────┘
                                       ▼
                           ┌───────────────────┐
                           │   Average/Vote    │
                           │    Predictions    │
                           └─────────┬─────────┘
                                     │
                                     ▼
                               Final: 98.3%
- Model Averaging: average softmax probabilities (+0.3-0.8%)
- Snapshot Ensemble: free ensemble from one training run (+0.2-0.5%)
- Test-Time Augmentation: multiple views of the same image (+0.2-0.4%)
- Architecture Mix: different network designs (+0.3-0.6%)
The simplest ensemble method: average the softmax probabilities (or logits) from multiple models, then take argmax.
ensemble/model_averaging.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
class EnsembleModel(nn.Module):
"""
Ensemble of multiple models with prediction averaging.
Combines predictions from multiple models by averaging their
softmax probabilities (soft voting) or logits.
Args:
models: List of trained models
mode: 'soft' (average probabilities) or 'logit' (average logits)
"""
def __init__(self, models: List[nn.Module], mode='soft'):
super().__init__()
self.models = nn.ModuleList(models)
self.mode = mode
for model in self.models:
model.eval()
def forward(self, x):
"""Get averaged predictions from all models."""
with torch.no_grad():
if self.mode == 'soft':
probs = []
for model in self.models:
logits = model(x)
probs.append(F.softmax(logits, dim=-1))
avg_probs = torch.stack(probs).mean(dim=0)
return avg_probs
else:
logits = []
for model in self.models:
logits.append(model(x))
avg_logits = torch.stack(logits).mean(dim=0)
return avg_logits
def predict(self, x):
"""Get class predictions."""
output = self.forward(x)
return output.argmax(dim=-1)
class WeightedEnsemble(nn.Module):
"""
Weighted ensemble with learnable or fixed weights.
Different models contribute differently based on their strengths.
Args:
models: List of trained models
weights: Optional list of weights (default: uniform)
learnable: If True, weights are learnable parameters
"""
def __init__(
self,
models: List[nn.Module],
weights: List[float] = None,
learnable: bool = False
):
super().__init__()
self.models = nn.ModuleList(models)
self.num_models = len(models)
if weights is None:
weights = [1.0 / self.num_models] * self.num_models
if learnable:
self.weights = nn.Parameter(torch.tensor(weights))
else:
self.register_buffer('weights', torch.tensor(weights))
for model in self.models:
model.eval()
    def forward(self, x):
        """Get weighted averaged predictions."""
        if isinstance(self.weights, nn.Parameter):
            # Learnable weights: softmax keeps them positive and summing to 1
            normalized_weights = F.softmax(self.weights, dim=0)
        else:
            # Fixed weights: renormalize by their sum so user-supplied values
            # keep their relative proportions (softmax would distort them)
            normalized_weights = self.weights / self.weights.sum()
        probs = []
        with torch.no_grad():
            for model in self.models:
                probs.append(F.softmax(model(x), dim=-1))
        # Apply weights outside no_grad so learnable weights receive gradients
        stacked = torch.stack(probs)  # (num_models, batch, num_classes)
        return (normalized_weights.view(-1, 1, 1) * stacked).sum(dim=0)
def evaluate_ensemble(models, test_loader, device):
"""Evaluate ensemble on test set."""
ensemble = EnsembleModel(models, mode='soft').to(device)
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
predictions = ensemble.predict(inputs)
correct += predictions.eq(targets).sum().item()
total += targets.size(0)
accuracy = 100 * correct / total
return accuracy
def find_optimal_weights(models, val_loader, device, num_iterations=100):
"""
Find optimal ensemble weights using grid search on validation set.
Returns weights that maximize validation accuracy.
"""
import numpy as np
from itertools import product
num_models = len(models)
weight_range = np.linspace(0.1, 1.0, 10)
best_acc = 0
best_weights = [1.0 / num_models] * num_models
if num_models <= 3:
for weights in product(weight_range, repeat=num_models):
weights = list(weights)
total = sum(weights)
weights = [w / total for w in weights]
ensemble = WeightedEnsemble(models, weights=weights).to(device)
acc = evaluate_single_ensemble(ensemble, val_loader, device)
if acc > best_acc:
best_acc = acc
best_weights = weights.copy()
else:
for _ in range(num_iterations):
weights = np.random.dirichlet(np.ones(num_models))
weights = weights.tolist()
ensemble = WeightedEnsemble(models, weights=weights).to(device)
acc = evaluate_single_ensemble(ensemble, val_loader, device)
if acc > best_acc:
best_acc = acc
best_weights = weights.copy()
return best_weights, best_acc
def evaluate_single_ensemble(ensemble, loader, device):
"""Helper to evaluate a single ensemble configuration."""
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in loader:
inputs, targets = inputs.to(device), targets.to(device)
probs = ensemble(inputs)
predictions = probs.argmax(dim=-1)
correct += predictions.eq(targets).sum().item()
total += targets.size(0)
return 100 * correct / total
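Soft voting (averaging probabilities) is not the same as hard majority voting over argmaxes: averaging lets one confident model outweigh two lukewarm ones. A tiny standalone illustration with hand-picked logits (not outputs of the models above):

```python
import torch
import torch.nn.functional as F

# Three "models" predicting over 3 classes for one image:
logits = torch.tensor([
    [4.0, 0.0, 0.0],   # model 1: very confident in class 0
    [0.0, 1.0, 0.0],   # model 2: mildly prefers class 1
    [0.0, 1.1, 0.0],   # model 3: mildly prefers class 1
])
probs = F.softmax(logits, dim=-1)

hard_vote = torch.mode(logits.argmax(dim=-1)).values  # majority of argmaxes
soft_vote = probs.mean(dim=0).argmax()                # average probabilities

print(hard_vote.item())  # 1: two weak votes outnumber one confident vote
print(soft_vote.item())  # 0: averaging takes confidence into account
```

Neither is always better, but soft voting is the usual default because it preserves calibration information that hard voting throws away.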
3-Model Ensemble (PyramidNet + WideResNet + DenseNet):
98.12%
Improvement from single model: +0.60%
Snapshot Ensembles (Huang et al., 2017) create an ensemble from a single training run using cyclic learning rates. At each cycle minimum, we save a "snapshot" of the model - these snapshots form our ensemble.
Why Snapshot Ensembles Work
The cyclic learning rate causes the model to explore different local minima. Each snapshot converges to a different minimum, providing natural diversity.
- Free ensemble: Only train once, get multiple models
- Diverse minima: Different weight configurations
- Same architecture: Easy to combine
Cyclic Learning Rate with Snapshots:

LR
 │▓            ▓            ▓
 │ ▓▓           ▓▓           ▓▓
 │   ▓▓           ▓▓           ▓▓
 │     ▓▓           ▓▓           ▓▓
 │       ▓▓           ▓▓           ▓▓
 │         ▓▓▓          ▓▓▓          ▓▓▓
 │            *            *            *
 └──────────────────────────────────────────
    Snapshot 1    Snapshot 2    Snapshot 3

 * = Save model at cycle minimum (local optimum)
ensemble/snapshot.py
import torch
import torch.nn as nn
import torch.optim as optim
import math
from copy import deepcopy

from ensemble.model_averaging import EnsembleModel  # used by get_ensemble below
class CyclicCosineScheduler:
"""
Cyclic cosine annealing scheduler for snapshot ensembles.
Learning rate follows cosine curve within each cycle,
restarting at max_lr at the beginning of each cycle.
Args:
optimizer: PyTorch optimizer
epochs_per_cycle: Epochs per restart cycle
num_cycles: Total number of cycles
min_lr: Minimum learning rate at cycle end
"""
def __init__(
self,
optimizer,
epochs_per_cycle,
num_cycles,
min_lr=1e-6
):
self.optimizer = optimizer
self.epochs_per_cycle = epochs_per_cycle
self.num_cycles = num_cycles
self.min_lr = min_lr
self.max_lr = optimizer.param_groups[0]['lr']
self.current_epoch = 0
def step(self):
"""Update learning rate for current epoch."""
cycle_position = (self.current_epoch % self.epochs_per_cycle) / self.epochs_per_cycle
lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
1 + math.cos(math.pi * cycle_position)
)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.current_epoch += 1
return lr
def is_cycle_end(self):
"""Check if we're at the end of a cycle (time to save snapshot)."""
return self.current_epoch % self.epochs_per_cycle == 0 and self.current_epoch > 0
def get_lr(self):
return self.optimizer.param_groups[0]['lr']
class SnapshotEnsembleTrainer:
"""
Trainer for Snapshot Ensembles.
Trains a single model with cyclic learning rate and saves
snapshots at each cycle minimum.
Paper: "Snapshot Ensembles: Train 1, get M for free"
Args:
model: Model to train
num_cycles: Number of cycles (= number of snapshots)
epochs_per_cycle: Training epochs per cycle
device: Training device
"""
def __init__(
self,
model,
num_cycles=5,
epochs_per_cycle=40,
device='cuda'
):
self.model = model.to(device)
self.num_cycles = num_cycles
self.epochs_per_cycle = epochs_per_cycle
self.device = device
self.snapshots = []
def train(
self,
train_loader,
test_loader,
base_lr=0.1,
weight_decay=5e-4
):
"""Train model and collect snapshots."""
total_epochs = self.num_cycles * self.epochs_per_cycle
optimizer = optim.SGD(
self.model.parameters(),
lr=base_lr,
momentum=0.9,
weight_decay=weight_decay,
nesterov=True
)
scheduler = CyclicCosineScheduler(
optimizer,
epochs_per_cycle=self.epochs_per_cycle,
num_cycles=self.num_cycles
)
criterion = nn.CrossEntropyLoss()
for epoch in range(total_epochs):
self.model.train()
for inputs, targets in train_loader:
inputs, targets = inputs.to(self.device), targets.to(self.device)
optimizer.zero_grad()
outputs = self.model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
lr = scheduler.step()
acc = self._evaluate(test_loader)
cycle_num = epoch // self.epochs_per_cycle + 1
print(f'Epoch {epoch+1}/{total_epochs} | LR: {lr:.6f} | Acc: {acc:.2f}% | Cycle: {cycle_num}')
if scheduler.is_cycle_end():
snapshot = deepcopy(self.model.state_dict())
self.snapshots.append(snapshot)
print(f' >> Saved snapshot {len(self.snapshots)} (Acc: {acc:.2f}%)')
return self.snapshots
def _evaluate(self, loader):
"""Evaluate model accuracy."""
self.model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in loader:
inputs, targets = inputs.to(self.device), targets.to(self.device)
outputs = self.model(inputs)
_, predicted = outputs.max(1)
correct += predicted.eq(targets).sum().item()
total += targets.size(0)
return 100 * correct / total
def get_ensemble(self, model_class, model_args):
"""
Create ensemble from saved snapshots.
Args:
model_class: Class to instantiate models
model_args: Arguments for model constructor
Returns:
EnsembleModel with all snapshots
"""
models = []
for i, snapshot in enumerate(self.snapshots):
model = model_class(**model_args)
model.load_state_dict(snapshot)
model.to(self.device)
model.eval()
models.append(model)
print(f'Loaded snapshot {i+1}')
return EnsembleModel(models, mode='soft')
def train_snapshot_ensemble(train_loader, test_loader, device='cuda'):
    from models.pyramidnet import pyramidnet110_a270
    from ensemble.model_averaging import evaluate_single_ensemble
    model = pyramidnet110_a270(num_classes=10)
    trainer = SnapshotEnsembleTrainer(
        model,
        num_cycles=5,
        epochs_per_cycle=40,
        device=device
    )
    trainer.train(train_loader, test_loader)
    ensemble = trainer.get_ensemble(
        model_class=pyramidnet110_a270,
        model_args={'num_classes': 10}
    )
    # Evaluate the ensemble directly: it already averages softmax probabilities,
    # so wrapping it in another EnsembleModel would apply softmax twice.
    acc = evaluate_single_ensemble(ensemble, test_loader, device)
    print(f'Snapshot Ensemble Accuracy: {acc:.2f}%')
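The restart schedule driving these snapshots is easy to sanity-check in isolation. This is the same cosine formula `CyclicCosineScheduler` applies, pulled out as a pure function:

```python
import math

def cyclic_cosine_lr(epoch, epochs_per_cycle=40, max_lr=0.1, min_lr=1e-6):
    """Learning rate under cosine annealing with warm restarts."""
    pos = (epoch % epochs_per_cycle) / epochs_per_cycle
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * pos))

print(f"{cyclic_cosine_lr(0):.4f}")   # 0.1000, cycle start at max LR
print(f"{cyclic_cosine_lr(20):.4f}")  # 0.0500, halfway down the cosine
print(f"{cyclic_cosine_lr(39):.4f}")  # near min_lr: snapshot territory
print(f"{cyclic_cosine_lr(40):.4f}")  # 0.1000, restart back to max
```

The abrupt jump from near `min_lr` back to `max_lr` at epoch 40 is the point: it kicks the model out of the minimum it just settled into, so the next snapshot lands somewhere different.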
5-Snapshot Ensemble (PyramidNet):
97.89%
Free +0.37% over single model!
Test-Time Augmentation applies augmentations to test images and averages predictions across all augmented versions. This is like creating an ensemble from different views of the same image.
ensemble/tta.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Callable
class TTAWrapper(nn.Module):
"""
Test-Time Augmentation wrapper for any model.
Applies multiple augmentations to each test image and averages
the predictions.
Args:
model: Trained model
transforms: List of transforms to apply
merge_mode: 'mean' (average probs) or 'max' (max voting)
"""
def __init__(
self,
model: nn.Module,
transforms: List[Callable] = None,
merge_mode: str = 'mean'
):
super().__init__()
self.model = model
self.merge_mode = merge_mode
if transforms is None:
self.transforms = [
lambda x: x,
lambda x: torch.flip(x, [-1]),
]
else:
self.transforms = transforms
self.model.eval()
def forward(self, x):
"""Get averaged predictions across all transforms."""
all_probs = []
with torch.no_grad():
for transform in self.transforms:
augmented = transform(x)
logits = self.model(augmented)
probs = F.softmax(logits, dim=-1)
all_probs.append(probs)
stacked = torch.stack(all_probs, dim=0)
if self.merge_mode == 'mean':
return stacked.mean(dim=0)
elif self.merge_mode == 'max':
return stacked.max(dim=0)[0]
else:
raise ValueError(f'Unknown merge mode: {self.merge_mode}')
class AdvancedTTA(nn.Module):
"""
Advanced TTA with more augmentation options.
Includes:
- Horizontal flip
- Small crops at different positions
- Color variations
"""
def __init__(self, model, num_crops=5, use_flips=True):
super().__init__()
self.model = model
self.num_crops = num_crops
self.use_flips = use_flips
self.model.eval()
def _get_crops(self, x, crop_size=32, num_crops=5):
"""Get multiple crops from padded image."""
B, C, H, W = x.shape
pad = 4
padded = F.pad(x, [pad, pad, pad, pad], mode='reflect')
crops = [x]
if num_crops > 1:
positions = [
(0, 0),
(0, 2 * pad),
(2 * pad, 0),
(2 * pad, 2 * pad),
]
for i, (top, left) in enumerate(positions[:num_crops-1]):
crop = padded[:, :, top:top+crop_size, left:left+crop_size]
crops.append(crop)
return crops
def forward(self, x):
"""Get TTA predictions."""
all_probs = []
with torch.no_grad():
crops = self._get_crops(x, num_crops=self.num_crops)
for crop in crops:
logits = self.model(crop)
all_probs.append(F.softmax(logits, dim=-1))
if self.use_flips:
flipped = torch.flip(crop, [-1])
logits = self.model(flipped)
all_probs.append(F.softmax(logits, dim=-1))
stacked = torch.stack(all_probs, dim=0)
return stacked.mean(dim=0)
def evaluate_with_tta(model, test_loader, device, num_crops=5):
"""Evaluate model with test-time augmentation."""
tta_model = AdvancedTTA(model, num_crops=num_crops, use_flips=True).to(device)
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
probs = tta_model(inputs)
predictions = probs.argmax(dim=-1)
correct += predictions.eq(targets).sum().item()
total += targets.size(0)
accuracy = 100 * correct / total
return accuracy
Single Model + TTA (5 crops + flips):
97.78%
Improvement from single model: +0.26%
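Before wiring TTA into a full evaluation loop, it helps to see the mechanics on a toy batch. The sketch below uses a tiny untrained network (a stand-in, not one of our trained models) with the same flip-and-average pattern as TTAWrapper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in model: conv + global average pool + linear head
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

x = torch.randn(4, 3, 32, 32)  # fake batch of CIFAR-sized images
with torch.no_grad():
    p_orig = model(x).softmax(dim=-1)
    p_flip = model(torch.flip(x, [-1])).softmax(dim=-1)  # horizontal flip
    tta_probs = (p_orig + p_flip) / 2  # average the two views

print(tta_probs.shape)  # torch.Size([4, 10])
print(bool(torch.allclose(tta_probs.sum(-1), torch.ones(4), atol=1e-5)))  # True
```

Averaging valid probability distributions yields another valid distribution, which is why TTA composes cleanly with every other soft-voting technique in this part.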
The most powerful ensembles combine models with different architectures. Different architectures learn different features and make different mistakes.
ensemble/diverse_ensemble.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class DiverseEnsemble(nn.Module):
"""
Ensemble combining multiple different architectures.
Achieves diversity through:
1. Different network architectures
2. Different training seeds
3. Different hyperparameters
Args:
model_configs: Dict of {name: (model, weight)} pairs
"""
def __init__(self, model_configs):
super().__init__()
self.models = nn.ModuleDict()
self.weights = {}
for name, (model, weight) in model_configs.items():
self.models[name] = model
self.weights[name] = weight
model.eval()
total_weight = sum(self.weights.values())
self.weights = {k: v / total_weight for k, v in self.weights.items()}
def forward(self, x):
"""Get weighted averaged predictions."""
with torch.no_grad():
weighted_probs = None
for name, model in self.models.items():
logits = model(x)
probs = F.softmax(logits, dim=-1)
weighted = self.weights[name] * probs
if weighted_probs is None:
weighted_probs = weighted
else:
weighted_probs += weighted
return weighted_probs
def predict(self, x):
"""Get class predictions."""
probs = self.forward(x)
return probs.argmax(dim=-1)
def build_diverse_ensemble(device='cuda', num_classes=10):
"""
Build a diverse ensemble from multiple architectures.
Each model should be trained separately with our full pipeline.
"""
from models.pyramidnet import pyramidnet110_a270
from models.wide_resnet import wrn_28_10
from models.densenet import densenet_bc_190_k40
from models.resnet import preact_resnet18
models = OrderedDict()
pyramid = pyramidnet110_a270(num_classes=num_classes)
pyramid.load_state_dict(torch.load('checkpoints/pyramidnet110_best.pth'))
pyramid.to(device)
models['pyramidnet'] = (pyramid, 1.2)
wrn = wrn_28_10(num_classes=num_classes)
wrn.load_state_dict(torch.load('checkpoints/wrn_28_10_best.pth'))
wrn.to(device)
models['wideresnet'] = (wrn, 1.0)
dense = densenet_bc_190_k40(num_classes=num_classes)
dense.load_state_dict(torch.load('checkpoints/densenet190_best.pth'))
dense.to(device)
models['densenet'] = (dense, 0.9)
resnet = preact_resnet18(num_classes=num_classes)
resnet.load_state_dict(torch.load('checkpoints/preact_resnet18_best.pth'))
resnet.to(device)
models['resnet'] = (resnet, 0.8)
ensemble = DiverseEnsemble(models)
return ensemble
def analyze_diversity(models, test_loader, device):
"""
Analyze prediction diversity between models.
Higher disagreement = better for ensemble.
"""
import numpy as np
model_names = list(models.keys())
num_models = len(model_names)
all_predictions = {name: [] for name in model_names}
with torch.no_grad():
for inputs, targets in test_loader:
inputs = inputs.to(device)
for name, (model, _) in models.items():
outputs = model(inputs)
preds = outputs.argmax(dim=-1).cpu().numpy()
all_predictions[name].extend(preds)
for name in model_names:
all_predictions[name] = np.array(all_predictions[name])
print("\nPairwise Disagreement Matrix:")
print("(Lower = more correlated, Higher = more diverse)")
print()
for i, name1 in enumerate(model_names):
row = []
for j, name2 in enumerate(model_names):
if i == j:
row.append("-")
else:
disagreement = (
all_predictions[name1] != all_predictions[name2]
).mean() * 100
row.append(f"{disagreement:.1f}%")
print(f"{name1:15s}: {row}")
class FullEnsembleWithTTA(nn.Module):
"""
Complete ensemble with TTA - the ultimate combination.
Combines:
1. Multiple diverse architectures
2. Test-time augmentation
This gives maximum accuracy at the cost of inference speed.
"""
def __init__(self, ensemble, tta_transforms=None):
super().__init__()
self.ensemble = ensemble
if tta_transforms is None:
self.tta_transforms = [
lambda x: x,
lambda x: torch.flip(x, [-1]),
]
else:
self.tta_transforms = tta_transforms
def forward(self, x):
"""Get ensemble + TTA predictions."""
all_probs = []
with torch.no_grad():
for transform in self.tta_transforms:
augmented = transform(x)
probs = self.ensemble(augmented)
all_probs.append(probs)
return torch.stack(all_probs).mean(dim=0)
def predict(self, x):
probs = self.forward(x)
return probs.argmax(dim=-1)
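The disagreement statistic in `analyze_diversity` boils down to comparing argmax arrays. A standalone check with toy predictions (not real model outputs):

```python
import numpy as np

preds_a = np.array([0, 1, 2, 2, 1, 0])  # toy predictions from model A
preds_b = np.array([0, 1, 1, 2, 0, 0])  # toy predictions from model B

disagreement = (preds_a != preds_b).mean() * 100
print(f"{disagreement:.1f}%")  # 33.3%: the models differ on 2 of 6 examples
```

Models in the high-97% range typically disagree on only a few percent of test images, and it is exactly those examples where the ensemble has a chance to recover errors.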
4-Architecture Ensemble + TTA:
98.47%
Combined improvement: +0.95%
ensemble/full_pipeline.py
import torch
import torch.nn.functional as F
from collections import OrderedDict
from tqdm import tqdm

from ensemble.diverse_ensemble import DiverseEnsemble, FullEnsembleWithTTA
def create_ultimate_ensemble(device='cuda', num_classes=10):
"""
Create the ultimate ensemble combining all techniques.
Components:
1. PyramidNet-110 (best single) + EMA weights
2. WideResNet-28-10 + EMA weights
3. DenseNet-BC-190 + EMA weights
4. 5 Snapshot models from PyramidNet training
5. Test-time augmentation on all models
Expected accuracy: ~98.5%
"""
models_config = OrderedDict()
print("Loading main models...")
from models.pyramidnet import pyramidnet110_a270
pyramid = pyramidnet110_a270(num_classes=num_classes)
pyramid.load_state_dict(torch.load('checkpoints/pyramidnet110_ema.pth'))
pyramid.to(device).eval()
models_config['pyramid_ema'] = (pyramid, 1.5)
from models.wide_resnet import wrn_28_10
wrn = wrn_28_10(num_classes=num_classes)
wrn.load_state_dict(torch.load('checkpoints/wrn_28_10_ema.pth'))
wrn.to(device).eval()
models_config['wrn_ema'] = (wrn, 1.2)
from models.densenet import densenet_bc_190_k40
dense = densenet_bc_190_k40(num_classes=num_classes)
dense.load_state_dict(torch.load('checkpoints/densenet190_ema.pth'))
dense.to(device).eval()
models_config['dense_ema'] = (dense, 1.0)
print("Loading snapshot models...")
for i in range(5):
snap = pyramidnet110_a270(num_classes=num_classes)
snap.load_state_dict(torch.load(f'checkpoints/pyramid_snapshot_{i+1}.pth'))
snap.to(device).eval()
models_config[f'snapshot_{i+1}'] = (snap, 0.6)
base_ensemble = DiverseEnsemble(models_config)
tta_transforms = [
lambda x: x,
lambda x: torch.flip(x, [-1]),
]
full_ensemble = FullEnsembleWithTTA(base_ensemble, tta_transforms)
print(f"Ensemble created with {len(models_config)} models + TTA")
return full_ensemble
def evaluate_full_ensemble(ensemble, test_loader, device):
"""Evaluate the full ensemble."""
correct = 0
total = 0
pbar = tqdm(test_loader, desc='Evaluating')
with torch.no_grad():
for inputs, targets in pbar:
inputs, targets = inputs.to(device), targets.to(device)
predictions = ensemble.predict(inputs)
correct += predictions.eq(targets).sum().item()
total += targets.size(0)
pbar.set_postfix({'acc': f'{100*correct/total:.2f}%'})
accuracy = 100 * correct / total
return accuracy
def main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torchvision
import torchvision.transforms as T
test_transform = T.Compose([
T.ToTensor(),
T.Normalize((0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616))
])
test_set = torchvision.datasets.CIFAR10(
'./data', train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_set, batch_size=100, shuffle=False, num_workers=4
)
ensemble = create_ultimate_ensemble(device)
accuracy = evaluate_full_ensemble(ensemble, test_loader, device)
print(f'\n{"="*50}')
print(f'Ultimate Ensemble Accuracy: {accuracy:.2f}%')
print(f'{"="*50}')
if __name__ == '__main__':
main()
Ensemble Configuration                       Accuracy   Improvement
Single PyramidNet + All Tricks (Part 4)      97.52%     Baseline
+ TTA (flip only)                            97.78%     +0.26%
+ Snapshot Ensemble (5 models)               97.89%     +0.37%
3-Model Diverse Ensemble                     98.12%     +0.60%
4-Model Ensemble + TTA                       98.47%     +0.95%
Ultimate Ensemble (8 models + TTA)           98.56%     +1.04%
Updated Progress
Current: 98.56% | Target: 99.5%
Gap remaining: 0.94%
Part 5 Key Takeaways
- Diversity is key - Different architectures make different mistakes
- Snapshot ensembles are free - One training run, multiple models
- TTA is simple but effective - Just flip the test images
- Weighted averaging helps - Give better models more influence
- Diminishing returns - First few models help most
- Inference cost matters - More models = slower predictions
Next: Part 6 - Final Push: Beating SOTA
We're at 98.56%! In the final part, we'll push to beat state-of-the-art with:
- Advanced architectures - NFNet, ConvNeXt adaptations
- Larger models - PyramidNet-272, WRN-40-10
- AutoML augmentation - Optimized policies
- Final optimizations - Every trick combined