CIFAR CHALLENGE SERIES - PART 5

Ensemble Methods: Combining Models for Maximum Accuracy

EDUSHARK TRAINING

Our Progress

Previous: 97.52% | Target: 99.5% | Gap remaining: 1.98%

1. Why Ensembles Work

Ensemble methods combine predictions from multiple models to achieve better performance than any single model. The key insight: different models make different mistakes.

Conditions for Ensemble Success

  • Accuracy: Each model should be reasonably accurate (better than random)
  • Diversity: Models should make different errors (not correlated)
  • Independence: Ideally, trained with different data/architectures/hyperparameters
Single Model:                 Ensemble of 5 Models:

Input Image                   Input Image
     │                             │
     ▼                      ┌──────┼──────┐
┌─────────┐                 │      │      │
│ Model   │                 ▼      ▼      ▼
│  97.5%  │            ┌─────┐ ┌─────┐ ┌─────┐
└────┬────┘            │M1   │ │M2   │ │M3   │
     │                 │97.5%│ │97.3%│ │97.1%│
     ▼                 └──┬──┘ └──┬──┘ └──┬──┘
 Prediction                │      │      │
                           │      │      │
                           ▼      ▼      ▼
                        ┌───────────────────┐
                        │ Average/Vote      │
                        │   Predictions     │
                        └─────────┬─────────┘
                                  │
                                  ▼
                           Final: 98.3%
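
The benefit of uncorrelated errors can be made concrete with a back-of-the-envelope calculation. The sketch below is illustrative only: it assumes binary decisions and fully independent errors, which real models never achieve, so actual ensemble gains (as in the diagram above) are much smaller.

```python
from math import comb

def majority_vote_accuracy(p: float, n: int) -> float:
    """Probability that a majority of n independent classifiers,
    each correct with probability p, votes for the right answer.
    Idealized: binary decisions, fully independent errors (n odd)."""
    k_min = n // 2 + 1  # smallest winning majority
    return sum(comb(n, k) * p**k * (1 - p)**(n - k) for k in range(k_min, n + 1))

print(f"{majority_vote_accuracy(0.975, 1):.4f}")  # 0.9750 (single model)
print(f"{majority_vote_accuracy(0.975, 5):.4f}")  # ~0.9998 under these assumptions
```

In practice five CIFAR models trained on the same data make highly correlated mistakes, which is why the realistic gain is a fraction of a percent rather than this idealized two orders of magnitude in error reduction.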
                        

Ensemble techniques and their typical gains:

  • Model Averaging - average softmax probabilities: +0.3-0.8%
  • Snapshot Ensemble - free ensemble from one training run: +0.2-0.5%
  • Test-Time Augmentation - multiple views of the same image: +0.2-0.4%
  • Architecture Mix - different network designs: +0.3-0.6%
2. Basic Model Averaging

The simplest ensemble method: average the softmax probabilities (or logits) from multiple models, then take argmax.

ensemble/model_averaging.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class EnsembleModel(nn.Module):
    """
    Ensemble of multiple models with prediction averaging.

    Combines predictions from multiple models by averaging their
    softmax probabilities (soft voting) or logits.

    Args:
        models: List of trained models
        mode: 'soft' (average probabilities) or 'logit' (average logits)
    """

    def __init__(self, models: List[nn.Module], mode='soft'):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.mode = mode

        # Set all models to eval mode
        for model in self.models:
            model.eval()

    def forward(self, x):
        """Get averaged predictions from all models."""
        with torch.no_grad():
            if self.mode == 'soft':
                # Average softmax probabilities
                probs = [F.softmax(model(x), dim=-1) for model in self.models]
                return torch.stack(probs).mean(dim=0)
            else:
                # Average raw logits
                logits = [model(x) for model in self.models]
                return torch.stack(logits).mean(dim=0)

    def predict(self, x):
        """Get class predictions."""
        output = self.forward(x)
        return output.argmax(dim=-1)


class WeightedEnsemble(nn.Module):
    """
    Weighted ensemble with learnable or fixed weights.

    Different models contribute differently based on their strengths.

    Args:
        models: List of trained models
        weights: Optional list of weights (default: uniform)
        learnable: If True, weights are learnable parameters
    """

    def __init__(
        self,
        models: List[nn.Module],
        weights: List[float] = None,
        learnable: bool = False
    ):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.num_models = len(models)
        self.learnable = learnable

        # Initialize weights (uniform by default)
        if weights is None:
            weights = [1.0 / self.num_models] * self.num_models

        if learnable:
            self.weights = nn.Parameter(torch.tensor(weights))
        else:
            self.register_buffer('weights', torch.tensor(weights))

        for model in self.models:
            model.eval()

    def forward(self, x):
        """Get weighted averaged predictions."""
        if self.learnable:
            # Softmax keeps learnable weights positive and summing to 1
            normalized_weights = F.softmax(self.weights, dim=0)
        else:
            # Fixed weights: plain normalization so the given values are respected
            normalized_weights = self.weights / self.weights.sum()

        probs = []
        with torch.no_grad():
            for model in self.models:
                probs.append(F.softmax(model(x), dim=-1))

        # Weighting happens outside no_grad so learnable weights receive gradients
        stacked = torch.stack(probs)  # (num_models, batch, classes)
        return (normalized_weights.view(-1, 1, 1) * stacked).sum(dim=0)


def evaluate_ensemble(models, test_loader, device):
    """Evaluate ensemble on test set."""
    ensemble = EnsembleModel(models, mode='soft').to(device)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = ensemble.predict(inputs)
            correct += predictions.eq(targets).sum().item()
            total += targets.size(0)

    return 100 * correct / total


def find_optimal_weights(models, val_loader, device, num_iterations=100):
    """
    Find optimal ensemble weights via search on the validation set.

    Returns weights that maximize validation accuracy.
    """
    import numpy as np
    from itertools import product

    num_models = len(models)

    weight_range = np.linspace(0.1, 1.0, 10)
    best_acc = 0
    best_weights = [1.0 / num_models] * num_models

    if num_models <= 3:
        # Grid search is feasible for a small number of models
        for weights in product(weight_range, repeat=num_models):
            weights = list(weights)
            total = sum(weights)
            weights = [w / total for w in weights]  # Normalize

            ensemble = WeightedEnsemble(models, weights=weights).to(device)
            acc = evaluate_single_ensemble(ensemble, val_loader, device)

            if acc > best_acc:
                best_acc = acc
                best_weights = weights.copy()
    else:
        # Random search (Dirichlet samples sum to 1) for more models
        for _ in range(num_iterations):
            weights = np.random.dirichlet(np.ones(num_models)).tolist()

            ensemble = WeightedEnsemble(models, weights=weights).to(device)
            acc = evaluate_single_ensemble(ensemble, val_loader, device)

            if acc > best_acc:
                best_acc = acc
                best_weights = weights.copy()

    return best_weights, best_acc


def evaluate_single_ensemble(ensemble, loader, device):
    """Helper to evaluate a single ensemble configuration."""
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            probs = ensemble(inputs)
            predictions = probs.argmax(dim=-1)
            correct += predictions.eq(targets).sum().item()
            total += targets.size(0)
    return 100 * correct / total

3-Model Ensemble (PyramidNet + WideResNet + DenseNet):

98.12%

Improvement from single model: +0.60%

3. Snapshot Ensembles

Snapshot Ensembles (Huang et al., 2017) create an ensemble from a single training run using cyclic learning rates. At each cycle minimum, we save a "snapshot" of the model - these snapshots form our ensemble.

Why Snapshot Ensembles Work

The cyclic learning rate causes the model to explore different local minima. Each snapshot converges to a different minimum, providing natural diversity.

  • Free ensemble: Only train once, get multiple models
  • Diverse minima: Different weight configurations
  • Same architecture: Easy to combine
Cyclic Learning Rate with Snapshots:

LR
│
│ ▓▓                    ▓▓                    ▓▓
│   ▓▓                ▓▓  ▓▓                ▓▓  ▓▓
│     ▓▓            ▓▓      ▓▓            ▓▓      ▓▓
│       ▓▓        ▓▓          ▓▓        ▓▓          ▓
│         ▓▓    ▓▓              ▓▓    ▓▓
│           ▓▓▓▓                  ▓▓▓▓
│             *                     *                *
└───────────────────────────────────────────────────────
            Snapshot 1          Snapshot 2      Snapshot 3

* = Save model at cycle minimum (local optimum)
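
The schedule sketched above follows a shifted cosine within each cycle, restarting at the maximum learning rate. A minimal standalone version (using the same defaults as the trainer below: base LR 0.1, 40 epochs per cycle) makes the restart behavior easy to inspect:

```python
import math

def cyclic_cosine_lr(epoch, max_lr=0.1, min_lr=1e-6, epochs_per_cycle=40):
    """LR for a given epoch under cyclic cosine annealing with warm restarts."""
    t = (epoch % epochs_per_cycle) / epochs_per_cycle  # position in cycle, [0, 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))

print(cyclic_cosine_lr(0))   # 0.1       (cycle start: max LR)
print(cyclic_cosine_lr(20))  # ~0.05     (mid-cycle)
print(cyclic_cosine_lr(39))  # ~0.00015  (near the cycle minimum -> save snapshot)
print(cyclic_cosine_lr(40))  # 0.1       (warm restart: back to max LR)
```

The jump from near-zero back to the maximum LR at epoch 40 is what kicks the model out of its current minimum so the next snapshot lands somewhere different.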
                        
ensemble/snapshot.py
import torch
import torch.nn as nn
import torch.optim as optim
import math
from copy import deepcopy

from ensemble.model_averaging import EnsembleModel, evaluate_single_ensemble


class CyclicCosineScheduler:
    """
    Cyclic cosine annealing scheduler for snapshot ensembles.

    Learning rate follows a cosine curve within each cycle,
    restarting at max_lr at the beginning of each cycle.

    Args:
        optimizer: PyTorch optimizer
        epochs_per_cycle: Epochs per restart cycle
        num_cycles: Total number of cycles
        min_lr: Minimum learning rate at cycle end
    """

    def __init__(self, optimizer, epochs_per_cycle, num_cycles, min_lr=1e-6):
        self.optimizer = optimizer
        self.epochs_per_cycle = epochs_per_cycle
        self.num_cycles = num_cycles
        self.min_lr = min_lr
        self.max_lr = optimizer.param_groups[0]['lr']
        self.current_epoch = 0

    def step(self):
        """Update learning rate for the current epoch."""
        # Position within current cycle (0 to 1)
        cycle_position = (self.current_epoch % self.epochs_per_cycle) / self.epochs_per_cycle

        # Cosine decay within cycle
        lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
            1 + math.cos(math.pi * cycle_position)
        )

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.current_epoch += 1
        return lr

    def is_cycle_end(self):
        """Check if we're at the end of a cycle (time to save a snapshot)."""
        return self.current_epoch % self.epochs_per_cycle == 0 and self.current_epoch > 0

    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']


class SnapshotEnsembleTrainer:
    """
    Trainer for Snapshot Ensembles.

    Trains a single model with a cyclic learning rate and saves
    snapshots at each cycle minimum.

    Paper: "Snapshot Ensembles: Train 1, get M for free"

    Args:
        model: Model to train
        num_cycles: Number of cycles (= number of snapshots)
        epochs_per_cycle: Training epochs per cycle
        device: Training device
    """

    def __init__(self, model, num_cycles=5, epochs_per_cycle=40, device='cuda'):
        self.model = model.to(device)
        self.num_cycles = num_cycles
        self.epochs_per_cycle = epochs_per_cycle
        self.device = device
        self.snapshots = []

    def train(self, train_loader, test_loader, base_lr=0.1, weight_decay=5e-4):
        """Train model and collect snapshots."""
        total_epochs = self.num_cycles * self.epochs_per_cycle

        optimizer = optim.SGD(
            self.model.parameters(),
            lr=base_lr,
            momentum=0.9,
            weight_decay=weight_decay,
            nesterov=True
        )

        scheduler = CyclicCosineScheduler(
            optimizer,
            epochs_per_cycle=self.epochs_per_cycle,
            num_cycles=self.num_cycles
        )

        criterion = nn.CrossEntropyLoss()

        for epoch in range(total_epochs):
            # Train one epoch
            self.model.train()
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

            # Update LR
            lr = scheduler.step()

            # Evaluate
            acc = self._evaluate(test_loader)
            cycle_num = epoch // self.epochs_per_cycle + 1
            print(f'Epoch {epoch+1}/{total_epochs} | LR: {lr:.6f} | '
                  f'Acc: {acc:.2f}% | Cycle: {cycle_num}')

            # Save snapshot at cycle end
            if scheduler.is_cycle_end():
                snapshot = deepcopy(self.model.state_dict())
                self.snapshots.append(snapshot)
                print(f'  >> Saved snapshot {len(self.snapshots)} (Acc: {acc:.2f}%)')

        return self.snapshots

    def _evaluate(self, loader):
        """Evaluate model accuracy."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                _, predicted = outputs.max(1)
                correct += predicted.eq(targets).sum().item()
                total += targets.size(0)
        return 100 * correct / total

    def get_ensemble(self, model_class, model_args):
        """
        Create an ensemble from the saved snapshots.

        Args:
            model_class: Class to instantiate models
            model_args: Arguments for model constructor

        Returns:
            EnsembleModel containing all snapshots
        """
        models = []
        for i, snapshot in enumerate(self.snapshots):
            model = model_class(**model_args)
            model.load_state_dict(snapshot)
            model.to(self.device)
            model.eval()
            models.append(model)
            print(f'Loaded snapshot {i+1}')

        return EnsembleModel(models, mode='soft')


# Example usage
def train_snapshot_ensemble(train_loader, test_loader):
    from models.pyramidnet import pyramidnet110_a270

    model = pyramidnet110_a270(num_classes=10)

    trainer = SnapshotEnsembleTrainer(
        model,
        num_cycles=5,
        epochs_per_cycle=40,
        device='cuda'
    )

    # Train and collect snapshots
    trainer.train(train_loader, test_loader)

    # Create ensemble from the snapshots
    ensemble = trainer.get_ensemble(
        model_class=pyramidnet110_a270,
        model_args={'num_classes': 10}
    )

    # Evaluate the snapshot ensemble
    acc = evaluate_single_ensemble(ensemble, test_loader, 'cuda')
    print(f'Snapshot Ensemble Accuracy: {acc:.2f}%')

5-Snapshot Ensemble (PyramidNet):

97.89%

Free +0.37% over single model!

4. Test-Time Augmentation (TTA)

Test-Time Augmentation applies augmentations to test images and averages predictions across all augmented versions. This is like creating an ensemble from different views of the same image.

ensemble/tta.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Callable


class TTAWrapper(nn.Module):
    """
    Test-Time Augmentation wrapper for any model.

    Applies multiple augmentations to each test image and
    averages the predictions.

    Args:
        model: Trained model
        transforms: List of transforms to apply
        merge_mode: 'mean' (average probs) or 'max' (max voting)
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: List[Callable] = None,
        merge_mode: str = 'mean'
    ):
        super().__init__()
        self.model = model
        self.merge_mode = merge_mode

        # Default TTA transforms
        if transforms is None:
            self.transforms = [
                lambda x: x,                    # Original
                lambda x: torch.flip(x, [-1]),  # Horizontal flip
            ]
        else:
            self.transforms = transforms

        self.model.eval()

    def forward(self, x):
        """Get averaged predictions across all transforms."""
        all_probs = []

        with torch.no_grad():
            for transform in self.transforms:
                augmented = transform(x)
                logits = self.model(augmented)
                all_probs.append(F.softmax(logits, dim=-1))

        # Stack and merge
        stacked = torch.stack(all_probs, dim=0)

        if self.merge_mode == 'mean':
            return stacked.mean(dim=0)
        elif self.merge_mode == 'max':
            return stacked.max(dim=0)[0]
        else:
            raise ValueError(f'Unknown merge mode: {self.merge_mode}')


class AdvancedTTA(nn.Module):
    """
    Advanced TTA with more augmentation options.

    Includes:
    - Horizontal flip
    - Crops at different positions of a padded image
    """

    def __init__(self, model, num_crops=5, use_flips=True):
        super().__init__()
        self.model = model
        self.num_crops = num_crops
        self.use_flips = use_flips
        self.model.eval()

    def _get_crops(self, x, crop_size=32, num_crops=5):
        """Get multiple crops from a reflection-padded image."""
        pad = 4
        padded = F.pad(x, [pad, pad, pad, pad], mode='reflect')

        crops = [x]  # Original (center) view

        if num_crops > 1:
            # Corner crops from the padded image
            positions = [
                (0, 0),              # Top-left
                (0, 2 * pad),        # Top-right
                (2 * pad, 0),        # Bottom-left
                (2 * pad, 2 * pad),  # Bottom-right
            ]
            for top, left in positions[:num_crops - 1]:
                crop = padded[:, :, top:top + crop_size, left:left + crop_size]
                crops.append(crop)

        return crops

    def forward(self, x):
        """Get TTA predictions."""
        all_probs = []

        with torch.no_grad():
            crops = self._get_crops(x, num_crops=self.num_crops)

            for crop in crops:
                # Original orientation
                logits = self.model(crop)
                all_probs.append(F.softmax(logits, dim=-1))

                # Horizontal flip
                if self.use_flips:
                    flipped = torch.flip(crop, [-1])
                    logits = self.model(flipped)
                    all_probs.append(F.softmax(logits, dim=-1))

        # Average all predictions
        stacked = torch.stack(all_probs, dim=0)
        return stacked.mean(dim=0)


def evaluate_with_tta(model, test_loader, device, num_crops=5):
    """Evaluate model with test-time augmentation."""
    tta_model = AdvancedTTA(model, num_crops=num_crops, use_flips=True).to(device)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            probs = tta_model(inputs)
            predictions = probs.argmax(dim=-1)
            correct += predictions.eq(targets).sum().item()
            total += targets.size(0)

    return 100 * correct / total

Single Model + TTA (5 crops + flips):

97.78%

Improvement from single model: +0.26%

5. Diverse Architecture Ensemble

The most powerful ensembles combine models with different architectures. Different architectures learn different features and make different mistakes.

ensemble/diverse_ensemble.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class DiverseEnsemble(nn.Module):
    """
    Ensemble combining multiple different architectures.

    Achieves diversity through:
    1. Different network architectures
    2. Different training seeds
    3. Different hyperparameters

    Args:
        model_configs: Dict of {name: (model, weight)} pairs
    """

    def __init__(self, model_configs):
        super().__init__()
        self.models = nn.ModuleDict()
        self.weights = {}

        for name, (model, weight) in model_configs.items():
            self.models[name] = model
            self.weights[name] = weight
            model.eval()

        # Normalize weights to sum to 1
        total_weight = sum(self.weights.values())
        self.weights = {k: v / total_weight for k, v in self.weights.items()}

    def forward(self, x):
        """Get weighted averaged predictions."""
        with torch.no_grad():
            weighted_probs = None
            for name, model in self.models.items():
                logits = model(x)
                probs = F.softmax(logits, dim=-1)
                weighted = self.weights[name] * probs

                if weighted_probs is None:
                    weighted_probs = weighted
                else:
                    weighted_probs += weighted

        return weighted_probs

    def predict(self, x):
        """Get class predictions."""
        probs = self.forward(x)
        return probs.argmax(dim=-1)


def build_diverse_ensemble(device='cuda', num_classes=10):
    """
    Build a diverse ensemble from multiple architectures.

    Each model should be trained separately with our full pipeline.
    """
    from models.pyramidnet import pyramidnet110_a270
    from models.wide_resnet import wrn_28_10
    from models.densenet import densenet_bc_190_k40
    from models.resnet import preact_resnet18

    # Load trained models
    models = OrderedDict()

    # PyramidNet-110 (best single model)
    pyramid = pyramidnet110_a270(num_classes=num_classes)
    pyramid.load_state_dict(torch.load('checkpoints/pyramidnet110_best.pth'))
    pyramid.to(device)
    models['pyramidnet'] = (pyramid, 1.2)  # Higher weight for best model

    # WideResNet-28-10
    wrn = wrn_28_10(num_classes=num_classes)
    wrn.load_state_dict(torch.load('checkpoints/wrn_28_10_best.pth'))
    wrn.to(device)
    models['wideresnet'] = (wrn, 1.0)

    # DenseNet-BC-190
    dense = densenet_bc_190_k40(num_classes=num_classes)
    dense.load_state_dict(torch.load('checkpoints/densenet190_best.pth'))
    dense.to(device)
    models['densenet'] = (dense, 0.9)

    # PreAct-ResNet-18 (smaller, different perspective)
    resnet = preact_resnet18(num_classes=num_classes)
    resnet.load_state_dict(torch.load('checkpoints/preact_resnet18_best.pth'))
    resnet.to(device)
    models['resnet'] = (resnet, 0.8)

    return DiverseEnsemble(models)


def analyze_diversity(models, test_loader, device):
    """
    Analyze prediction diversity between models.

    Higher disagreement = better for ensembling.
    """
    import numpy as np

    model_names = list(models.keys())
    all_predictions = {name: [] for name in model_names}

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            for name, (model, _) in models.items():
                outputs = model(inputs)
                preds = outputs.argmax(dim=-1).cpu().numpy()
                all_predictions[name].extend(preds)

    # Convert to arrays
    for name in model_names:
        all_predictions[name] = np.array(all_predictions[name])

    # Compute pairwise disagreement
    print("\nPairwise Disagreement Matrix:")
    print("(Lower = more correlated, Higher = more diverse)\n")

    for name1 in model_names:
        row = []
        for name2 in model_names:
            if name1 == name2:
                row.append("-")
            else:
                disagreement = (
                    all_predictions[name1] != all_predictions[name2]
                ).mean() * 100
                row.append(f"{disagreement:.1f}%")
        print(f"{name1:15s}: {row}")


class FullEnsembleWithTTA(nn.Module):
    """
    Complete ensemble with TTA - the ultimate combination.

    Combines:
    1. Multiple diverse architectures
    2. Test-time augmentation

    This gives maximum accuracy at the cost of inference speed.
    """

    def __init__(self, ensemble, tta_transforms=None):
        super().__init__()
        self.ensemble = ensemble

        if tta_transforms is None:
            self.tta_transforms = [
                lambda x: x,                    # Original
                lambda x: torch.flip(x, [-1]),  # H-flip
            ]
        else:
            self.tta_transforms = tta_transforms

    def forward(self, x):
        """Get ensemble + TTA predictions."""
        all_probs = []

        with torch.no_grad():
            for transform in self.tta_transforms:
                augmented = transform(x)
                probs = self.ensemble(augmented)
                all_probs.append(probs)

        # Average across TTA views
        return torch.stack(all_probs).mean(dim=0)

    def predict(self, x):
        probs = self.forward(x)
        return probs.argmax(dim=-1)

4-Architecture Ensemble + TTA:

98.47%

Combined improvement: +0.95%

6. Complete Ensemble Pipeline

ensemble/full_pipeline.py
import torch
from collections import OrderedDict
from tqdm import tqdm

from ensemble.diverse_ensemble import DiverseEnsemble, FullEnsembleWithTTA


def create_ultimate_ensemble(device='cuda', num_classes=10):
    """
    Create the ultimate ensemble combining all techniques.

    Components:
    1. PyramidNet-110 (best single) + EMA weights
    2. WideResNet-28-10 + EMA weights
    3. DenseNet-BC-190 + EMA weights
    4. 5 snapshot models from PyramidNet training
    5. Test-time augmentation on all models

    Expected accuracy: ~98.5%
    """
    models_config = OrderedDict()

    # Main models with EMA weights
    print("Loading main models...")

    # PyramidNet
    from models.pyramidnet import pyramidnet110_a270
    pyramid = pyramidnet110_a270(num_classes=num_classes)
    pyramid.load_state_dict(torch.load('checkpoints/pyramidnet110_ema.pth'))
    pyramid.to(device).eval()
    models_config['pyramid_ema'] = (pyramid, 1.5)

    # WideResNet
    from models.wide_resnet import wrn_28_10
    wrn = wrn_28_10(num_classes=num_classes)
    wrn.load_state_dict(torch.load('checkpoints/wrn_28_10_ema.pth'))
    wrn.to(device).eval()
    models_config['wrn_ema'] = (wrn, 1.2)

    # DenseNet
    from models.densenet import densenet_bc_190_k40
    dense = densenet_bc_190_k40(num_classes=num_classes)
    dense.load_state_dict(torch.load('checkpoints/densenet190_ema.pth'))
    dense.to(device).eval()
    models_config['dense_ema'] = (dense, 1.0)

    # Snapshot models
    print("Loading snapshot models...")
    for i in range(5):
        snap = pyramidnet110_a270(num_classes=num_classes)
        snap.load_state_dict(torch.load(f'checkpoints/pyramid_snapshot_{i+1}.pth'))
        snap.to(device).eval()
        models_config[f'snapshot_{i+1}'] = (snap, 0.6)

    # Create base ensemble
    base_ensemble = DiverseEnsemble(models_config)

    # Wrap with TTA
    tta_transforms = [
        lambda x: x,                    # Original
        lambda x: torch.flip(x, [-1]),  # H-flip
    ]
    full_ensemble = FullEnsembleWithTTA(base_ensemble, tta_transforms)

    print(f"Ensemble created with {len(models_config)} models + TTA")
    return full_ensemble


def evaluate_full_ensemble(ensemble, test_loader, device):
    """Evaluate the full ensemble."""
    correct = 0
    total = 0

    pbar = tqdm(test_loader, desc='Evaluating')
    with torch.no_grad():
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = ensemble.predict(inputs)
            correct += predictions.eq(targets).sum().item()
            total += targets.size(0)
            pbar.set_postfix({'acc': f'{100*correct/total:.2f}%'})

    return 100 * correct / total


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load test data
    import torchvision
    import torchvision.transforms as T

    test_transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    test_set = torchvision.datasets.CIFAR10(
        './data', train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=100, shuffle=False, num_workers=4
    )

    # Create and evaluate ensemble
    ensemble = create_ultimate_ensemble(device)
    accuracy = evaluate_full_ensemble(ensemble, test_loader, device)

    print(f'\n{"="*50}')
    print(f'Ultimate Ensemble Accuracy: {accuracy:.2f}%')
    print(f'{"="*50}')


if __name__ == '__main__':
    main()
7. Results Summary

Ensemble Configuration                       Accuracy   Improvement
Single PyramidNet + All Tricks (Part 4)      97.52%     Baseline
+ TTA (flip only)                            97.78%     +0.26%
+ Snapshot Ensemble (5 models)               97.89%     +0.37%
3-Model Diverse Ensemble                     98.12%     +0.60%
4-Model Ensemble + TTA                       98.47%     +0.95%
Ultimate Ensemble (8 models + TTA)           98.56%     +1.04%

Updated Progress

Current: 98.56% | Target: 99.5% | Gap remaining: 0.94%

Part 5 Key Takeaways

  • Diversity is key - Different architectures make different mistakes
  • Snapshot ensembles are free - One training run, multiple models
  • TTA is simple but effective - Just flip the test images
  • Weighted averaging helps - Give better models more influence
  • Diminishing returns - First few models help most
  • Inference cost matters - More models = slower predictions
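
That last point is easy to quantify as a back-of-the-envelope estimate: forward passes per test image scale as (number of models) × (TTA views), ignoring differences in per-model cost.

```python
def forward_passes(num_models: int, tta_views: int) -> int:
    """Forward passes needed per test image for an ensemble with TTA.

    Back-of-the-envelope only: treats every model as equally expensive.
    """
    return num_models * tta_views

configs = {
    "single model":                            (1, 1),
    "single model + flip TTA":                 (1, 2),
    "3-model diverse ensemble":                (3, 1),
    "ultimate ensemble (8 models + flip TTA)": (8, 2),
}
for name, (m, v) in configs.items():
    print(f"{name}: {forward_passes(m, v)}x forward passes")
```

So the ultimate ensemble's +1.04% accuracy comes at roughly 16x the inference cost of the single model, which is worth keeping in mind for any deployment scenario.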

Next: Part 6 - Final Push: Beating SOTA

We're at 98.56%! In the final part, we'll push to beat state-of-the-art with:

  • Advanced architectures - NFNet, ConvNeXt adaptations
  • Larger models - PyramidNet-272, WRN-40-10
  • AutoML augmentation - Optimized policies
  • Final optimizations - Every trick combined