CIFAR CHALLENGE SERIES - PART 4
Training Tricks: Label Smoothing, Distillation, SWA & More
Our Progress
Previous: 96.38% | Target: 99.5%
Gap remaining: 3.12%
We have a powerful PyramidNet architecture at 96.38%. Now we'll apply advanced training techniques to push even further. Each trick addresses a specific limitation:
- Label Smoothing: prevents overconfident predictions (+0.3-0.5%)
- Knowledge Distillation: transfers knowledge from larger models (+0.5-1.0%)
- Stochastic Depth: regularization through layer dropout (+0.2-0.4%)
- LR Schedules: optimized learning rate decay (+0.2-0.5%)
- SWA/EMA: weight averaging for better generalization (+0.3-0.6%)
- Gradient Clipping: stabilizes training of deep networks (stability rather than a direct accuracy gain)
Label Smoothing (Szegedy et al., 2016) prevents the model from becoming overconfident by softening the target distribution.
The Problem with Hard Labels
Standard cross-entropy with hard labels (one-hot vectors) encourages the model to:
- Push logits to extreme values (+inf for correct class, -inf for others)
- Become overconfident (100% certainty even when wrong)
- Overfit to the training labels
Hard Labels: y = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] (100% cat)
Soft Labels (epsilon=0.1):
y_smooth = (1 - epsilon) * y + epsilon / num_classes
y_smooth = [0.01, 0.01, 0.91, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
training/label_smoothing.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
"""
Cross-entropy loss with label smoothing.
Instead of one-hot targets, uses soft targets:
y_smooth = (1 - epsilon) * y_one_hot + epsilon / num_classes
This prevents overconfident predictions and improves calibration.
Args:
epsilon: Smoothing factor (0 = no smoothing, 1 = uniform)
reduction: 'mean', 'sum', or 'none'
"""
def __init__(self, epsilon=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
def forward(self, logits, targets):
"""
Args:
logits: (batch_size, num_classes) raw model outputs
targets: (batch_size,) class indices
Returns:
Smoothed cross-entropy loss
"""
num_classes = logits.size(-1)
log_probs = F.log_softmax(logits, dim=-1)
with torch.no_grad():
smooth_targets = torch.full_like(log_probs, self.epsilon / num_classes)
smooth_targets.scatter_(
1,
targets.unsqueeze(1),
1.0 - self.epsilon + self.epsilon / num_classes
)
loss = -torch.sum(smooth_targets * log_probs, dim=-1)
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
class LabelSmoothingCrossEntropyV2(nn.Module):
"""
Alternative implementation using KL divergence perspective.
    Loss = (1 - epsilon) * CE(logits, targets) + epsilon * CE(logits, uniform)
This is mathematically equivalent but sometimes more numerically stable.
"""
def __init__(self, epsilon=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
self.log_softmax = nn.LogSoftmax(dim=-1)
def forward(self, logits, targets):
num_classes = logits.size(-1)
log_probs = self.log_softmax(logits)
nll_loss = F.nll_loss(log_probs, targets, reduction=self.reduction)
smooth_loss = -log_probs.mean(dim=-1)
if self.reduction == 'mean':
smooth_loss = smooth_loss.mean()
elif self.reduction == 'sum':
smooth_loss = smooth_loss.sum()
loss = (1.0 - self.epsilon) * nll_loss + self.epsilon * smooth_loss
return loss
def train_with_label_smoothing():
model = WideResNet(28, 10)
criterion = LabelSmoothingCrossEntropy(epsilon=0.1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
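If you are on PyTorch 1.10 or newer, label smoothing is also built into the standard cross-entropy loss, so the custom class above can be swapped for a one-liner:

import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # built-in equivalent of epsilon=0.1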
Label Smoothing (epsilon=0.1) with PyramidNet-110:
96.72%
Improvement: +0.34%
Knowledge Distillation (Hinton et al., 2015) transfers knowledge from a large "teacher" model to a smaller "student" model. The key insight: soft targets from the teacher contain more information than hard labels.
Why Distillation Works
The teacher's soft predictions encode relationships between classes:
- A cat image might get: [0.05 dog, 0.8 cat, 0.05 tiger, 0.02 car, ...]
- The small dog probability tells the student "cats look somewhat like dogs"
- This "dark knowledge" helps the student learn better representations
Distillation Loss:
L = alpha * L_hard + (1 - alpha) * T^2 * L_soft
Where:
- L_hard = CrossEntropy(student_logits, true_labels)
- L_soft = KL_Divergence(softmax(student_logits/T), softmax(teacher_logits/T))
- T = Temperature (typically 3-20)
- alpha = Balance factor (typically 0.1-0.5)
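To make the role of the temperature concrete, here is a tiny illustrative check (the logits are made up, not from a real teacher): dividing by T before the softmax spreads probability mass onto the non-target classes, which is exactly the "dark knowledge" the KL term transfers.

import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([[2.0, 8.0, 3.0, -1.0]])  # hypothetical 4-class teacher output
print(F.softmax(teacher_logits, dim=-1))        # T=1: ~[0.002, 0.991, 0.007, 0.000] (very sharp)
print(F.softmax(teacher_logits / 4.0, dim=-1))  # T=4: ~[0.138, 0.619, 0.177, 0.065] (similarities visible)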
training/distillation.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
"""
Knowledge Distillation Loss.
Combines hard label loss (cross-entropy) with soft label loss (KL divergence
between student and teacher softmax outputs at temperature T).
Args:
temperature: Softmax temperature (higher = softer distributions)
alpha: Weight for hard label loss (1-alpha for soft loss)
"""
def __init__(self, temperature=4.0, alpha=0.1):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits, teacher_logits, targets):
"""
Args:
student_logits: Raw outputs from student model
teacher_logits: Raw outputs from teacher model (no grad needed)
targets: True class labels
Returns:
Combined distillation loss
"""
hard_loss = self.ce_loss(student_logits, targets)
student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
with torch.no_grad():
teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
soft_loss = self.kl_loss(student_soft, teacher_soft) * (self.temperature ** 2)
loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
return loss
class SelfDistillationLoss(nn.Module):
"""
Self-Distillation: Use the model's own predictions as soft targets.
Train with a mix of:
1. Cross-entropy with true labels
2. Consistency with model's predictions at lower temperature
This acts as a form of regularization.
"""
def __init__(self, temperature=3.0, alpha=0.5):
super().__init__()
self.temperature = temperature
self.alpha = alpha
def forward(self, logits, targets, prev_logits=None):
hard_loss = F.cross_entropy(logits, targets)
if prev_logits is None:
return hard_loss
student_soft = F.log_softmax(logits / self.temperature, dim=-1)
teacher_soft = F.softmax(prev_logits.detach() / self.temperature, dim=-1)
soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
soft_loss = soft_loss * (self.temperature ** 2)
return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
3.1 Complete Distillation Training Pipeline
training/distillation.py (continued)
import torch
from tqdm import tqdm
class DistillationTrainer:
"""
Knowledge Distillation Trainer.
Trains a student network to match both:
1. True labels (hard targets)
2. Teacher's soft predictions
Example usage:
teacher = load_pretrained_wrn_28_10()
student = create_resnet18()
trainer = DistillationTrainer(teacher, student, ...)
trainer.train(train_loader, epochs=200)
"""
def __init__(
self,
teacher,
student,
temperature=4.0,
alpha=0.1,
device='cuda'
):
self.teacher = teacher.to(device).eval()
self.student = student.to(device)
self.device = device
for param in self.teacher.parameters():
param.requires_grad = False
self.criterion = DistillationLoss(temperature=temperature, alpha=alpha)
def train_epoch(self, loader, optimizer, scheduler=None):
self.student.train()
total_loss = 0
correct = 0
total = 0
pbar = tqdm(loader, desc='Distillation')
for inputs, targets in pbar:
inputs = inputs.to(self.device)
targets = targets.to(self.device)
with torch.no_grad():
teacher_logits = self.teacher(inputs)
student_logits = self.student(inputs)
loss = self.criterion(student_logits, teacher_logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * inputs.size(0)
_, predicted = student_logits.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
pbar.set_postfix({
'loss': f'{loss.item():.4f}',
'acc': f'{100*correct/total:.2f}%'
})
if scheduler:
scheduler.step()
return total_loss / total, 100 * correct / total
def evaluate(self, loader):
self.student.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in loader:
inputs = inputs.to(self.device)
targets = targets.to(self.device)
outputs = self.student(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return 100 * correct / total
def train(self, train_loader, test_loader, epochs, optimizer, scheduler=None):
best_acc = 0
for epoch in range(epochs):
print(f'\nEpoch {epoch + 1}/{epochs}')
train_loss, train_acc = self.train_epoch(
train_loader, optimizer, scheduler
)
test_acc = self.evaluate(test_loader)
print(f'Train Loss: {train_loss:.4f} | Test Acc: {test_acc:.2f}%')
if test_acc > best_acc:
best_acc = test_acc
torch.save(self.student.state_dict(), 'distilled_model.pth')
print(f'New best: {best_acc:.2f}%')
return best_acc
def run_distillation():
teacher = wrn_28_10(num_classes=10)
teacher.load_state_dict(torch.load('wrn_28_10_best.pth'))
student = pyramidnet110_a270(num_classes=10)
trainer = DistillationTrainer(
teacher=teacher,
student=student,
temperature=4.0,
alpha=0.1
)
optimizer = torch.optim.SGD(
student.parameters(),
lr=0.1, momentum=0.9, weight_decay=5e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
best_acc = trainer.train(
train_loader, test_loader,
epochs=200,
optimizer=optimizer,
scheduler=scheduler
)
print(f'Final best accuracy: {best_acc:.2f}%')
Knowledge Distillation (T=4, alpha=0.1):
97.05%
Improvement: +0.33% (cumulative: +0.67%)
Stochastic Depth (Huang et al., 2016) randomly drops entire layers during training. This acts as regularization and reduces training time while improving generalization.
How It Works
During training, each residual block has a probability of being "skipped" (identity function). The survival probability decreases linearly with depth:
- Early layers (close to input): High survival probability (~1.0)
- Deep layers: Lower survival probability (down to ~0.5)
- At test time: All layers are used with scaled outputs
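As a quick worked example of the linear decay rule (the helper below is purely illustrative and not part of the models module):

def linear_survival_probs(num_blocks, p_last=0.5):
    # Block 0 (nearest the input) survives with probability 1.0,
    # the deepest block with probability p_last; everything in between is linear.
    return [1.0 - (i / (num_blocks - 1)) * (1.0 - p_last) for i in range(num_blocks)]

probs = linear_survival_probs(18)
print(probs[0], probs[8], probs[-1])  # 1.0, ~0.76, 0.5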
models/stochastic_depth.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class StochasticDepthBlock(nn.Module):
"""
Residual block with stochastic depth (layer dropout).
During training: skip the block with probability (1 - survival_prob)
During inference: scale output by survival_prob
This is equivalent to Dropout but at the layer level.
"""
def __init__(self, block, survival_prob=1.0):
"""
Args:
block: The residual block to wrap
survival_prob: Probability of keeping this block during training
"""
super().__init__()
self.block = block
self.survival_prob = survival_prob
    def forward(self, x):
        if not self.training:
            # Inference: use every block, scaled by its survival probability
            return x + self.survival_prob * self.block(x)
        if torch.rand(1).item() > self.survival_prob:
            # Block dropped for this pass: pure identity
            return x
        # Block kept: no rescaling during training; scaling happens at inference (Huang et al., 2016)
        return x + self.block(x)
def add_stochastic_depth(model, survival_prob_last=0.5):
"""
Add stochastic depth to an existing ResNet-style model.
Wraps each residual block with StochasticDepthBlock.
Survival probability decreases linearly from 1.0 to survival_prob_last.
Args:
model: ResNet/WideResNet/PyramidNet model
survival_prob_last: Survival probability for the last block
Returns:
Modified model with stochastic depth
"""
blocks = []
for name, module in model.named_modules():
if 'BasicBlock' in type(module).__name__ or \
'Bottleneck' in type(module).__name__:
blocks.append((name, module))
total_blocks = len(blocks)
for idx, (name, block) in enumerate(blocks):
survival_prob = 1.0 - (idx / total_blocks) * (1.0 - survival_prob_last)
parent = model
parts = name.split('.')
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], StochasticDepthBlock(block, survival_prob))
return model
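# Example usage (pyramidnet110_a270 is the factory from Part 3, shown here only for illustration):
#   model = pyramidnet110_a270(num_classes=10)
#   model = add_stochastic_depth(model, survival_prob_last=0.5)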
class StochasticDepthResNet(nn.Module):
"""
ResNet with built-in stochastic depth.
Alternative: build stochastic depth into the architecture from scratch.
"""
def __init__(self, block, layers, num_classes=10, survival_prob_last=0.5):
super().__init__()
self.in_channels = 64
total_blocks = sum(layers)
self.block_idx = 0
self.total_blocks = total_blocks
self.survival_prob_last = survival_prob_last
self.conv1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _get_survival_prob(self):
"""Get survival probability for current block (linear decay)."""
self.block_idx += 1
prob = 1.0 - (self.block_idx / self.total_blocks) * (1.0 - self.survival_prob_last)
return prob
def _make_layer(self, block, channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for s in strides:
survival_prob = self._get_survival_prob()
b = block(self.in_channels, channels, s)
layers.append(StochasticDepthBlock(b, survival_prob))
self.in_channels = channels * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.avgpool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
Stochastic Depth (p_last=0.5):
97.21%
Improvement: +0.16% (cumulative: +0.83%)
The learning rate schedule significantly impacts final accuracy. We'll implement several advanced schedules beyond basic step decay.
Learning rate over time (schematic): step decay holds the LR flat and drops it sharply at fixed milestones; cosine annealing decays it smoothly toward zero; warmup + cosine ramps the LR up linearly over the first few epochs before the cosine decay; cosine with restarts repeats the cosine curve over successive restart periods, often doubling their length after each restart.
training/schedulers.py
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler
class WarmupCosineScheduler(_LRScheduler):
"""
Cosine annealing with linear warmup.
1. Linear warmup: LR increases from 0 to base_lr over warmup_epochs
2. Cosine decay: LR decreases from base_lr to min_lr following cosine curve
This is one of the most effective schedules for training deep networks.
"""
def __init__(
self,
optimizer,
warmup_epochs,
total_epochs,
min_lr=0.0,
last_epoch=-1
):
self.warmup_epochs = warmup_epochs
self.total_epochs = total_epochs
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch < self.warmup_epochs:
alpha = self.last_epoch / self.warmup_epochs
return [base_lr * alpha for base_lr in self.base_lrs]
else:
progress = (self.last_epoch - self.warmup_epochs) / (
self.total_epochs - self.warmup_epochs
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
return [
self.min_lr + (base_lr - self.min_lr) * cosine_decay
for base_lr in self.base_lrs
]
class CosineAnnealingWarmRestarts(_LRScheduler):
"""
Cosine annealing with warm restarts (SGDR).
Paper: "SGDR: Stochastic Gradient Descent with Warm Restarts"
LR follows cosine curve within each restart period.
Period can increase after each restart (T_mult > 1).
"""
def __init__(
self,
optimizer,
T_0,
T_mult=1,
eta_min=0,
last_epoch=-1
):
self.T_0 = T_0
self.T_i = T_0
self.T_mult = T_mult
self.eta_min = eta_min
self.T_cur = 0
super().__init__(optimizer, last_epoch)
def get_lr(self):
return [
self.eta_min + (base_lr - self.eta_min) *
(1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
for base_lr in self.base_lrs
]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.T_cur += 1
if self.T_cur >= self.T_i:
self.T_cur = 0
self.T_i = self.T_i * self.T_mult
self.last_epoch = epoch
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class OneCycleLR(_LRScheduler):
"""
One Cycle Learning Rate Policy.
Paper: "Super-Convergence" (Smith & Topin, 2018)
Three phases:
    1. Linear warmup from initial_lr up to max_lr over the first pct_start of steps
    2. Cosine anneal from max_lr back down to initial_lr over the remaining steps
    3. After total_steps, hold the LR at final_lr
Often enables training with much higher learning rates.
"""
def __init__(
self,
optimizer,
max_lr,
total_steps,
pct_start=0.3,
div_factor=25,
final_div_factor=10000,
last_epoch=-1
):
self.max_lr = max_lr
self.total_steps = total_steps
self.pct_start = pct_start
self.div_factor = div_factor
self.final_div_factor = final_div_factor
self.initial_lr = max_lr / div_factor
self.final_lr = self.initial_lr / final_div_factor
self.step_count = 0
super().__init__(optimizer, last_epoch)
def get_lr(self):
pct = self.step_count / self.total_steps
warmup_pct = self.pct_start
if pct < warmup_pct:
lr = self.initial_lr + (self.max_lr - self.initial_lr) * (pct / warmup_pct)
elif pct < 1.0:
progress = (pct - warmup_pct) / (1.0 - warmup_pct)
lr = self.initial_lr + (self.max_lr - self.initial_lr) * \
(1 + math.cos(math.pi * progress)) / 2
else:
lr = self.final_lr
return [lr for _ in self.base_lrs]
def step(self):
self.step_count += 1
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
def create_schedulers(optimizer, epochs, warmup=5):
"""Create different scheduler options."""
schedulers = {
'cosine': torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs
),
'warmup_cosine': WarmupCosineScheduler(
optimizer,
warmup_epochs=warmup,
total_epochs=epochs,
min_lr=1e-6
),
'cosine_restart': CosineAnnealingWarmRestarts(
optimizer,
T_0=50,
T_mult=2,
eta_min=1e-6
),
'step': torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[60, 120, 160],
gamma=0.2
),
}
return schedulers
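One practical note: the epoch-based options above (cosine, warmup-cosine, restarts, step) are meant to be stepped once per epoch, while OneCycleLR is defined over total_steps and should be stepped once per optimizer update. A minimal sketch of that per-step wiring (model, optimizer and train_loader as defined elsewhere; the peak LR of 0.4 is just an example value):

import torch.nn.functional as F

scheduler = OneCycleLR(optimizer, max_lr=0.4, total_steps=len(train_loader) * 200)
for inputs, targets in train_loader:  # inside each epoch
    loss = F.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # step the schedule every batch, not every epoch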
Warmup Cosine Schedule (5 epochs warmup):
97.34%
Improvement: +0.13% (cumulative: +0.96%)
Weight averaging techniques improve generalization by maintaining a running average of model weights during training.
SWA (Stochastic Weight Averaging)
- Average weights from multiple training checkpoints
- Typically start averaging after 75% of training
- Uses cyclic or constant learning rate
- Requires BN statistics update after averaging
EMA (Exponential Moving Average)
- Maintain running average throughout training
- Recent weights weighted more heavily
- No BN update needed (gradual transition)
- Simpler to implement
training/weight_averaging.py
import torch
import torch.nn as nn
from copy import deepcopy
class EMA:
"""
Exponential Moving Average of model weights.
Maintains a shadow copy of model weights as an exponentially
decaying average of the training weights.
EMA_weights = decay * EMA_weights + (1 - decay) * model_weights
Args:
model: The model to track
decay: EMA decay rate (0.999 or 0.9999 typical)
"""
def __init__(self, model, decay=0.9999):
self.model = model
self.decay = decay
self.shadow = deepcopy(model)
self.backup = None
for param in self.shadow.parameters():
param.requires_grad = False
@torch.no_grad()
def update(self):
"""Update EMA weights (call after each training step)."""
for ema_param, model_param in zip(
self.shadow.parameters(), self.model.parameters()
):
ema_param.data.mul_(self.decay).add_(
model_param.data, alpha=1 - self.decay
)
for ema_buf, model_buf in zip(
self.shadow.buffers(), self.model.buffers()
):
ema_buf.data.copy_(model_buf.data)
def apply_shadow(self):
"""Apply EMA weights to model (for evaluation)."""
self.backup = deepcopy(self.model.state_dict())
self.model.load_state_dict(self.shadow.state_dict())
def restore(self):
"""Restore original model weights (after evaluation)."""
if self.backup is not None:
self.model.load_state_dict(self.backup)
self.backup = None
class SWA:
"""
Stochastic Weight Averaging.
Maintains a simple average of model weights over multiple
checkpoints during training.
Paper: "Averaging Weights Leads to Wider Optima and Better Generalization"
Usage:
1. Train normally until SWA start epoch
2. Call update() periodically (e.g., end of each epoch)
3. After training, call update_bn() with training data
4. Use averaged model for inference
"""
def __init__(self, model):
self.model = model
self.swa_model = deepcopy(model)
self.n_averaged = 0
for param in self.swa_model.parameters():
param.data.zero_()
param.requires_grad = False
@torch.no_grad()
def update(self):
"""Add current weights to running average."""
self.n_averaged += 1
for swa_param, model_param in zip(
self.swa_model.parameters(), self.model.parameters()
):
swa_param.data.add_(
(model_param.data - swa_param.data) / self.n_averaged
)
@torch.no_grad()
def update_bn(self, train_loader, device):
"""
Update BatchNorm statistics for SWA model.
After averaging weights, BN running statistics are no longer valid.
This recomputes them using the training data.
"""
self.swa_model.train()
for module in self.swa_model.modules():
if isinstance(module, nn.BatchNorm2d):
module.reset_running_stats()
module.momentum = None
with torch.no_grad():
for inputs, _ in train_loader:
inputs = inputs.to(device)
self.swa_model(inputs)
self.swa_model.eval()
def get_model(self):
"""Get the averaged model."""
return self.swa_model
def train_with_ema(model, train_loader, test_loader, epochs, device):
"""Training loop with EMA."""
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
criterion = nn.CrossEntropyLoss()
ema = EMA(model, decay=0.9999)
best_acc = 0
for epoch in range(epochs):
model.train()
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
ema.update()
scheduler.step()
ema.apply_shadow()
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
acc = 100 * correct / total
print(f'Epoch {epoch+1}: EMA Accuracy = {acc:.2f}%')
if acc > best_acc:
best_acc = acc
torch.save(model.state_dict(), 'best_ema_model.pth')
ema.restore()
return best_acc
def train_with_swa(model, train_loader, test_loader, epochs, swa_start, device):
"""Training loop with SWA (start averaging after swa_start epochs)."""
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, swa_start)
criterion = nn.CrossEntropyLoss()
swa = SWA(model)
for epoch in range(epochs):
model.train()
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if epoch < swa_start:
scheduler.step()
else:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.05
swa.update()
swa.update_bn(train_loader, device)
swa_model = swa.get_model()
swa_model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = swa_model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
acc = 100 * correct / total
print(f'SWA Final Accuracy: {acc:.2f}%')
return acc
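For reference, PyTorch 1.6+ also ships SWA helpers in torch.optim.swa_utils that mirror the hand-rolled SWA class above. A minimal sketch, reusing the model, optimizer, scheduler, loaders, epochs and swa_start from train_with_swa (train_one_epoch is a hypothetical stand-in for the inner loop):

from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

swa_model = AveragedModel(model)               # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # constant LR during the SWA phase
for epoch in range(epochs):
    train_one_epoch(model, train_loader, optimizer)
    if epoch < swa_start:
        scheduler.step()
    else:
        swa_model.update_parameters(model)     # add current weights to the average
        swa_scheduler.step()
update_bn(train_loader, swa_model, device=device)  # recompute BN running statistics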
EMA (decay=0.9999):
97.52%
Improvement: +0.18% (cumulative: +1.14%)
Let's combine all tricks into a single training script.
train_full_pipeline.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm
import argparse
from models.pyramidnet import pyramidnet110_a270
from augmentations.cutout import Cutout
from augmentations.mixup import MixupCutMix, mixup_criterion
from training.label_smoothing import LabelSmoothingCrossEntropy
from training.schedulers import WarmupCosineScheduler
from training.weight_averaging import EMA
def get_transforms():
"""Get training and test transforms with full augmentation."""
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2616)
train_transform = T.Compose([
T.RandomCrop(32, padding=4, padding_mode='reflect'),
T.RandomHorizontalFlip(),
T.AutoAugment(policy=T.AutoAugmentPolicy.CIFAR10),
T.ToTensor(),
T.Normalize(MEAN, STD),
Cutout(n_holes=1, length=16),
])
test_transform = T.Compose([
T.ToTensor(),
T.Normalize(MEAN, STD),
])
return train_transform, test_transform
def train_epoch(model, loader, criterion, optimizer, device, ema, mixup_fn):
"""Train for one epoch with all tricks."""
model.train()
total_loss = 0
correct = 0
total = 0
pbar = tqdm(loader, desc='Training')
for inputs, targets in pbar:
inputs, targets = inputs.to(device), targets.to(device)
if mixup_fn is not None:
inputs, targets_a, targets_b, lam = mixup_fn(inputs, targets)
outputs = model(inputs)
if mixup_fn is not None:
loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
else:
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()
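        # Gradient clipping (the "stability" trick from the overview): cap the global gradient norm at 1.0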
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
ema.update()
total_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
total += targets.size(0)
if mixup_fn is not None:
correct += (lam * predicted.eq(targets_a).sum().float() +
(1 - lam) * predicted.eq(targets_b).sum().float()).item()
else:
correct += predicted.eq(targets).sum().item()
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
return total_loss / total, 100 * correct / total
def evaluate(model, loader, device):
"""Evaluate on test set."""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return 100 * correct / total
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--warmup', type=int, default=5)
parser.add_argument('--label_smoothing', type=float, default=0.1)
parser.add_argument('--ema_decay', type=float, default=0.9999)
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
train_transform, test_transform = get_transforms()
train_set = torchvision.datasets.CIFAR10(
'./data', train=True, download=True, transform=train_transform
)
test_set = torchvision.datasets.CIFAR10(
'./data', train=False, download=True, transform=test_transform
)
train_loader = DataLoader(
train_set, batch_size=args.batch_size, shuffle=True,
num_workers=4, pin_memory=True
)
test_loader = DataLoader(
test_set, batch_size=args.batch_size, shuffle=False,
num_workers=4, pin_memory=True
)
model = pyramidnet110_a270(num_classes=10).to(device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')
ema = EMA(model, decay=args.ema_decay)
criterion = LabelSmoothingCrossEntropy(epsilon=args.label_smoothing)
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=0.9,
weight_decay=5e-4,
nesterov=True
)
scheduler = WarmupCosineScheduler(
optimizer,
warmup_epochs=args.warmup,
total_epochs=args.epochs,
min_lr=1e-6
)
mixup_fn = MixupCutMix(mixup_alpha=0.2, cutmix_alpha=1.0)
best_acc = 0
best_ema_acc = 0
for epoch in range(args.epochs):
print(f'\nEpoch {epoch + 1}/{args.epochs} (LR: {scheduler.get_lr()[0]:.6f})')
train_loss, train_acc = train_epoch(
model, train_loader, criterion, optimizer, device, ema, mixup_fn
)
test_acc = evaluate(model, test_loader, device)
ema.apply_shadow()
ema_acc = evaluate(model, test_loader, device)
ema.restore()
print(f'Train Loss: {train_loss:.4f}')
print(f'Test Acc: {test_acc:.2f}% | EMA Acc: {ema_acc:.2f}%')
if test_acc > best_acc:
best_acc = test_acc
if ema_acc > best_ema_acc:
best_ema_acc = ema_acc
ema.apply_shadow()
torch.save(model.state_dict(), 'best_model_ema.pth')
ema.restore()
print(f'New best EMA: {best_ema_acc:.2f}%')
scheduler.step()
print(f'\nFinal Results:')
print(f'Best Regular: {best_acc:.2f}%')
print(f'Best EMA: {best_ema_acc:.2f}%')
if __name__ == '__main__':
main()
| Technique | Accuracy | Improvement |
|---|---|---|
| PyramidNet-110 Baseline (Part 3) | 96.38% | - |
| + Label Smoothing (0.1) | 96.72% | +0.34% |
| + Knowledge Distillation | 97.05% | +0.33% |
| + Stochastic Depth | 97.21% | +0.16% |
| + Warmup + Cosine Schedule | 97.34% | +0.13% |
| + EMA (0.9999) | 97.52% | +0.18% |
Updated Progress
Current: 97.52% | Target: 99.5%
Gap remaining: 1.98%
Part 4 Key Takeaways
- Label smoothing prevents overconfidence - Small epsilon (0.1) gives consistent gains
- Distillation adds "dark knowledge" - Teacher soft targets contain class relationships
- Stochastic depth regularizes deep networks - Layer dropout during training
- Warmup helps with large learning rates - Gradual start prevents early divergence
- EMA smooths noisy training - Simple but consistently improves results
- Small gains compound - Each trick adds 0.1-0.5%, totaling 1.14%
Next: Part 5 - Ensemble Methods
We're at 97.52% with a single model. In Part 5, we'll combine multiple models for even better results:
- Model Averaging - Average predictions from multiple models
- Snapshot Ensembles - Free ensemble from cyclic learning rates
- Diverse Architectures - Combine ResNet, WideResNet, PyramidNet
- Test-Time Augmentation - Multiple predictions per sample