CIFAR CHALLENGE SERIES - PART 3
Advanced CNN Architectures: ResNet, WideResNet, PyramidNet, DenseNet
Our Progress
Previous: 93.12% | Target: 99.5%
Gap remaining: 6.38%
In Part 2, we pushed our simple 6-layer CNN to 93.12% using aggressive data augmentation. But our architecture is now the bottleneck. Let's understand why:
Limitations of Our Simple CNN
- Vanishing gradients: Cannot train deeper than ~10 layers
- Limited capacity: Only 2.85M parameters
- No feature reuse: Each layer only sees the previous layer's output
- Fixed receptive field growth: Linear increase with depth
Modern architectures solve these problems through clever design patterns. We'll implement four architectures from scratch:
ResNet
~0.3M - 11M params
Skip connections enable training of very deep networks
WideResNet
~36M params
Wider layers instead of deeper - better accuracy-to-depth ratio
PyramidNet
~26M params
Gradually increasing channels - eliminates representational bottleneck
DenseNet
~0.8M - 25M params
Dense connections - maximum feature reuse
ResNet (He et al., 2015) revolutionized deep learning by introducing skip connections (also called shortcut connections or residual connections). This simple idea enables training networks with 100+ layers.
The Key Insight
Instead of learning H(x) directly, learn the residual F(x) = H(x) - x. The network outputs F(x) + x, making it easier to learn identity mappings when needed.
Standard Block:                    Residual Block:

   Input x                            Input x ──────────────────┐
      │                                  │                      │
      ▼                                  ▼                      │
 ┌─────────┐                        ┌─────────┐                 │
 │ Conv 3x3│                        │ Conv 3x3│                 │
 │   BN    │                        │   BN    │                 │
 │  ReLU   │                        │  ReLU   │                 │
 └────┬────┘                        └────┬────┘                 │
      │                                  │                      │
      ▼                                  ▼                      │
 ┌─────────┐                        ┌─────────┐                 │
 │ Conv 3x3│                        │ Conv 3x3│                 │
 │   BN    │                        │   BN    │                 │
 │  ReLU   │                        └────┬────┘                 │
 └────┬────┘                             │                      │
      │                                  ▼                      │
      ▼                                ( + ) ◄──────────────────┘
   Output                                │
                                         ▼
                                       ReLU
                                         │
                                         ▼
                              Output = ReLU(F(x) + x)
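To see why the identity path matters for gradients, here is a tiny standalone sketch (ours, not part of the models/ files in this series): with y = F(x) + x, the Jacobian of y with respect to x is dF/dx + I, so the backward signal always has a direct route around the block, regardless of how small dF/dx becomes.

import torch
import torch.nn as nn

# A stand-in residual branch F: two 3x3 convs with BN (weights are random here).
f = nn.Sequential(
    nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
)

x = torch.randn(2, 16, 32, 32, requires_grad=True)

plain = f(x).sum()            # plain block: gradient w.r.t. x is only dF/dx
residual = (f(x) + x).sum()   # residual block: gradient w.r.t. x is dF/dx + I

g_plain, = torch.autograd.grad(plain, x)
g_residual, = torch.autograd.grad(residual, x)
print(g_plain.abs().mean().item(), g_residual.abs().mean().item())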
2.1 Basic Block vs Bottleneck Block
ResNet uses two types of blocks:
Basic Block (ResNet-18/34)
- Two 3×3 convolutions
- Used for shallower networks
- Lower computational cost
Bottleneck Block (ResNet-50+)
- 1×1 → 3×3 → 1×1 convolutions
- Reduces then restores channels
- More efficient for deep networks
models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
"""
Basic residual block for ResNet-18/34.
Structure: Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> (+shortcut) -> ReLU
Args:
in_channels: Number of input channels
out_channels: Number of output channels
stride: Stride for first convolution (used for downsampling)
"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * self.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels * self.expansion,
kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(out_channels * self.expansion)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
"""
Bottleneck residual block for ResNet-50/101/152.
Structure: Conv1x1 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU -> Conv1x1 -> BN -> (+shortcut) -> ReLU
The 1x1 convolutions reduce then restore the number of channels,
making the 3x3 convolution less computationally expensive.
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=1, bias=False
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(
out_channels, out_channels * self.expansion,
kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * self.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels * self.expansion,
kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(out_channels * self.expansion)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
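Before wiring these blocks into a full network, it helps to sanity-check their output shapes. The snippet below is our own quick check (run it after defining the blocks above, e.g. in a notebook); the tensor sizes are illustrative.

import torch

x = torch.randn(2, 64, 32, 32)

basic = BasicBlock(64, 128, stride=2)      # stride 2 halves the spatial size
print(basic(x).shape)                      # torch.Size([2, 128, 16, 16])

bottleneck = Bottleneck(64, 64, stride=1)  # expansion=4 -> 4*64 output channels
print(bottleneck(x).shape)                 # torch.Size([2, 256, 32, 32])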
2.2 Full ResNet Implementation
models/resnet.py (continued)
class ResNet(nn.Module):
"""
ResNet for CIFAR-10/100.
Modified from original ImageNet version:
- First conv: 3x3 with stride=1 (instead of 7x7 with stride=2)
- No max pooling after first conv
- Smaller feature maps throughout
Args:
block: BasicBlock or Bottleneck
num_blocks: List of blocks per stage [stage1, stage2, stage3, stage4]
num_classes: Number of output classes (10 for CIFAR-10, 100 for CIFAR-100)
"""
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
self._initialize_weights()
def _make_layer(self, block, out_channels, num_blocks, stride):
"""
Create a residual stage with multiple blocks.
First block may downsample (stride > 1), rest maintain spatial size.
"""
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)
def _initialize_weights(self):
"""Initialize weights using He initialization."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.avgpool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
def resnet18(num_classes=10):
"""ResNet-18: ~11M parameters"""
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
def resnet34(num_classes=10):
"""ResNet-34: ~21M parameters"""
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
def resnet50(num_classes=10):
"""ResNet-50: ~23M parameters"""
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
def resnet101(num_classes=10):
"""ResNet-101: ~42M parameters"""
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
def resnet110(num_classes=10):
    """ResNet-110: popular for CIFAR, ~1.7M parameters.
    The canonical ResNet-110 is a 3-stage network (16/32/64 channels,
    18 BasicBlocks per stage), which the 4-stage class above cannot
    express, so we build it with the CIFAR-style helper sketched below."""
    return CIFARResNet(BasicBlock, [18, 18, 18], num_classes)
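The canonical CIFAR ResNet-110 from He et al. starts at 16 channels and has only three stages. Below is a minimal sketch of such a class; the name CIFARResNet is our own label, and it reuses BasicBlock with projection shortcuts (the paper used zero-padded identity shortcuts), which changes the parameter count only marginally. Since resnet110() is only resolved at call time, defining the helper after the factory is fine.

# Minimal sketch of the 3-stage CIFAR-style ResNet (our own helper class).
class CIFARResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(CIFARResNet, self).__init__()
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)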
2.3 Pre-activation ResNet (ResNet-v2)
The original ResNet places BatchNorm and ReLU after the convolution. Pre-activation ResNet (He et al., 2016) reverses this order, enabling even deeper networks and better gradient flow.
models/preact_resnet.py
class PreActBlock(nn.Module):
"""
Pre-activation version of BasicBlock.
Structure: BN -> ReLU -> Conv -> BN -> ReLU -> Conv -> (+shortcut)
Key difference: Activation comes BEFORE convolution.
Benefits:
- Easier optimization
- Better regularization
- Can train deeper networks
"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super(PreActBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False
)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * self.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels * self.expansion,
kernel_size=1, stride=stride, bias=False
)
)
    def forward(self, x):
        out = F.relu(self.bn1(x))
        # The identity path carries x itself; the 1x1 projection (when one is
        # needed) is applied to the pre-activated features instead.
        shortcut = self.shortcut(out) if len(self.shortcut) > 0 else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out
class PreActResNet(nn.Module):
"""Pre-activation ResNet for CIFAR."""
def __init__(self, block, num_blocks, num_classes=10):
super(PreActResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.bn = nn.BatchNorm2d(512 * block.expansion)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.relu(self.bn(out))
out = self.avgpool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
def preact_resnet18(num_classes=10):
return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes)
ResNet-18 with our augmentation pipeline:
94.52%
Improvement from simple CNN: +1.40%
WideResNet (Zagoruyko & Komodakis, 2016) challenges the assumption that deeper is always better. By widening layers instead of deepening, it achieves better accuracy with fewer layers.
Key Finding
A 16-layer WideResNet with width multiplier 8 (WRN-16-8) outperforms a 1000-layer thin ResNet! Wider networks have better:
- Representational capacity - More features per layer
- Parallelization - Better GPU utilization
- Training speed - Fewer sequential operations
WideResNet uses a naming convention: WRN-d-k where d is depth and k is the width multiplier.
models/wide_resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class WideBasicBlock(nn.Module):
"""
Wide residual block with dropout.
Uses pre-activation design (BN -> ReLU -> Conv).
Includes dropout between convolutions for regularization.
"""
def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.0):
super(WideBasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False
)
self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False
)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels,
kernel_size=1, stride=stride, bias=False
)
)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.dropout(out)
out = self.conv2(F.relu(self.bn2(out)))
out += self.shortcut(x)
return out
class WideResNet(nn.Module):
"""
Wide Residual Network.
Paper: "Wide Residual Networks" (Zagoruyko & Komodakis, 2016)
Args:
depth: Total depth (must be 6n + 4 for n blocks per group)
widen_factor: Width multiplier (k in WRN-d-k notation)
dropout_rate: Dropout probability between convolutions
num_classes: Number of output classes
Example configurations:
WRN-28-10: depth=28, widen_factor=10 -> 36.5M params, ~96.1% on CIFAR-10
WRN-40-4: depth=40, widen_factor=4 -> 8.9M params, ~95.5% on CIFAR-10
WRN-16-8: depth=16, widen_factor=8 -> 11M params, ~95.3% on CIFAR-10
"""
def __init__(self, depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10):
super(WideResNet, self).__init__()
assert (depth - 4) % 6 == 0, 'Depth must be 6n + 4'
n = (depth - 4) // 6
channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
self.in_channels = channels[0]
self.conv1 = nn.Conv2d(
3, channels[0], kernel_size=3,
stride=1, padding=1, bias=False
)
self.group1 = self._make_group(
channels[1], n, stride=1, dropout_rate=dropout_rate
)
self.group2 = self._make_group(
channels[2], n, stride=2, dropout_rate=dropout_rate
)
self.group3 = self._make_group(
channels[3], n, stride=2, dropout_rate=dropout_rate
)
self.bn = nn.BatchNorm2d(channels[3])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(channels[3], num_classes)
self._initialize_weights()
def _make_group(self, out_channels, num_blocks, stride, dropout_rate):
"""Create a group of wide residual blocks."""
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(
WideBasicBlock(
self.in_channels, out_channels,
stride=stride, dropout_rate=dropout_rate
)
)
self.in_channels = out_channels
return nn.Sequential(*layers)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def forward(self, x):
out = self.conv1(x)
out = self.group1(out)
out = self.group2(out)
out = self.group3(out)
out = F.relu(self.bn(out))
out = self.avgpool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
def wrn_28_10(num_classes=10, dropout=0.3):
"""WideResNet-28-10: 36.5M params, ~96% accuracy"""
return WideResNet(depth=28, widen_factor=10, dropout_rate=dropout, num_classes=num_classes)
def wrn_40_4(num_classes=10, dropout=0.3):
"""WideResNet-40-4: 8.9M params, good accuracy-to-size ratio"""
return WideResNet(depth=40, widen_factor=4, dropout_rate=dropout, num_classes=num_classes)
def wrn_16_8(num_classes=10, dropout=0.3):
"""WideResNet-16-8: 11M params, fast training"""
return WideResNet(depth=16, widen_factor=8, dropout_rate=dropout, num_classes=num_classes)
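If you want to verify the parameter counts quoted in the docstring, a quick check like the following works (the helper below uses the same logic as the count_parameters function defined later in train_advanced.py):

# Quick parameter-count check for the configurations above (illustrative).
def _num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'WRN-28-10: {_num_params(wrn_28_10()) / 1e6:.1f}M')  # ~36.5M
print(f'WRN-40-4:  {_num_params(wrn_40_4()) / 1e6:.1f}M')   # ~8.9M
print(f'WRN-16-8:  {_num_params(wrn_16_8()) / 1e6:.1f}M')   # ~11.0M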
WideResNet-28-10 with our augmentation pipeline:
96.14%
Improvement from ResNet-18: +1.62%
PyramidNet (Han et al., 2017) addresses a key issue in ResNet: the sudden jump in feature dimensions when transitioning between stages creates a "representational bottleneck."
The Problem with Standard ResNets
In ResNet, channels suddenly double at each stage transition (64 → 128 → 256 → 512). This creates:
- Information bottleneck at transitions
- Unbalanced computational load
- Some layers do more "work" than others
Solution: Gradually increase channels at every layer!
ResNet channel growth:            PyramidNet channel growth:

Stage 1: [64, 64, 64]             [16, 20, 24, 28]
             ↓                            ↓
Stage 2: [128, 128, 128]          [32, 36, 40, 44]
             ↓                            ↓
Stage 3: [256, 256, 256]          [48, 52, 56, 60]
             ↓                            ↓
Stage 4: [512, 512, 512]          [64, 68, 72, ...]

Sudden jumps at transitions!      Smooth increase of α/N channels per block
                                  (α = total widening, N = total blocks)
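To make the "smooth increase" concrete, here is a small sketch of the channel schedule for a BasicBlock PyramidNet. The per-block increment is α divided by the total number of blocks, which is exactly the add_rate used in the implementation below; the helper pyramid_channels is ours, for illustration only.

# Channel schedule for a BasicBlock PyramidNet: start at 16 channels and add
# alpha / N channels per block, where N is the total number of blocks.
def pyramid_channels(depth, alpha):
    n = (depth - 2) // 6            # BasicBlocks per stage (2 convs per block)
    total_blocks = 3 * n
    add_rate = alpha / total_blocks
    return [int(round(16 + add_rate * i)) for i in range(1, total_blocks + 1)]

widths = pyramid_channels(depth=110, alpha=270)
print(widths[:4], '...', widths[-4:])   # [21, 26, 31, 36] ... [271, 276, 281, 286]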
models/pyramidnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PyramidBasicBlock(nn.Module):
"""
PyramidNet basic block with channel padding.
Instead of 1x1 conv for dimension matching, PyramidNet pads
with zeros - saving parameters and computation.
"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super(PyramidBasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False
)
self.bn3 = nn.BatchNorm2d(out_channels)
self.stride = stride
self.in_channels = in_channels
self.out_channels = out_channels
def forward(self, x):
out = self.conv1(self.bn1(x))
out = self.conv2(F.relu(self.bn2(out)))
out = self.bn3(out)
shortcut = x
if self.stride != 1:
shortcut = F.avg_pool2d(shortcut, 2)
if self.in_channels != self.out_channels:
pad_channels = self.out_channels - self.in_channels
shortcut = F.pad(
shortcut,
(0, 0, 0, 0, 0, pad_channels)
)
out += shortcut
return out
class PyramidBottleneck(nn.Module):
"""
PyramidNet bottleneck block with channel padding.
Uses 1x1 -> 3x3 -> 1x1 structure with gradual channel increase.
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super(PyramidBottleneck, self).__init__()
bottleneck_channels = out_channels // 4
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(
in_channels, bottleneck_channels,
kernel_size=1, bias=False
)
self.bn2 = nn.BatchNorm2d(bottleneck_channels)
self.conv2 = nn.Conv2d(
bottleneck_channels, bottleneck_channels,
kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn3 = nn.BatchNorm2d(bottleneck_channels)
self.conv3 = nn.Conv2d(
bottleneck_channels, out_channels,
kernel_size=1, bias=False
)
self.bn4 = nn.BatchNorm2d(out_channels)
self.stride = stride
self.in_channels = in_channels
self.out_channels = out_channels
def forward(self, x):
out = self.conv1(self.bn1(x))
out = self.conv2(F.relu(self.bn2(out)))
out = self.conv3(F.relu(self.bn3(out)))
out = self.bn4(out)
shortcut = x
if self.stride != 1:
shortcut = F.avg_pool2d(shortcut, 2)
if self.in_channels != self.out_channels:
pad_channels = self.out_channels - self.in_channels
shortcut = F.pad(shortcut, (0, 0, 0, 0, 0, pad_channels))
out += shortcut
return out
class PyramidNet(nn.Module):
"""
PyramidNet: Deep Pyramidal Residual Networks.
Paper: "Deep Pyramidal Residual Networks" (Han et al., 2017)
Args:
depth: Network depth (e.g., 110, 164, 200, 272)
alpha: Widening factor - total channel increase across network
block: PyramidBasicBlock or PyramidBottleneck
num_classes: Number of output classes
Example configurations:
PyramidNet-110 (alpha=270): depth=110, alpha=270, BasicBlock
PyramidNet-200 (alpha=240): depth=200, alpha=240, Bottleneck
PyramidNet-272 (alpha=200): depth=272, alpha=200, Bottleneck
"""
def __init__(self, depth, alpha, block, num_classes=10):
super(PyramidNet, self).__init__()
if block == PyramidBasicBlock:
n = (depth - 2) // 6
else:
n = (depth - 2) // 9
self.in_channels = 16
self.alpha = alpha
self.n = n
self.total_blocks = n * 3
self.block_idx = 0
self.add_rate = alpha / self.total_blocks
self.conv1 = nn.Conv2d(
3, 16, kernel_size=3,
stride=1, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(16)
self.group1 = self._make_group(block, n, stride=1)
self.group2 = self._make_group(block, n, stride=2)
self.group3 = self._make_group(block, n, stride=2)
self.final_channels = int(round(16 + alpha)) * block.expansion
self.bn_final = nn.BatchNorm2d(self.final_channels)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(self.final_channels, num_classes)
self._initialize_weights()
def _make_group(self, block, num_blocks, stride):
"""Create a group with gradually increasing channels."""
layers = []
for i in range(num_blocks):
self.block_idx += 1
out_channels = int(round(16 + self.add_rate * self.block_idx)) * block.expansion
s = stride if i == 0 else 1
layers.append(block(self.in_channels, out_channels, stride=s))
self.in_channels = out_channels
return nn.Sequential(*layers)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.group1(out)
out = self.group2(out)
out = self.group3(out)
out = F.relu(self.bn_final(out))
out = self.avgpool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
def pyramidnet110_a270(num_classes=10):
"""PyramidNet-110 with alpha=270 (BasicBlock): ~28M params"""
return PyramidNet(depth=110, alpha=270, block=PyramidBasicBlock, num_classes=num_classes)
def pyramidnet200_a240(num_classes=10):
"""PyramidNet-200 with alpha=240 (Bottleneck): ~26M params"""
return PyramidNet(depth=200, alpha=240, block=PyramidBottleneck, num_classes=num_classes)
def pyramidnet272_a200(num_classes=10):
"""PyramidNet-272 with alpha=200 (Bottleneck): ~26M params"""
return PyramidNet(depth=272, alpha=200, block=PyramidBottleneck, num_classes=num_classes)
PyramidNet-110 (alpha=270) with our augmentation pipeline:
96.38%
Improvement from WideResNet: +0.24%
DenseNet (Huang et al., 2017) takes skip connections to the extreme: every layer is connected to every other layer in a feedforward fashion. This creates maximum feature reuse and significantly reduces the number of parameters.
Dense Connectivity
In a dense block with L layers, there are L(L+1)/2 connections (vs. L in a standard network). Each layer receives feature maps from ALL preceding layers:
- Layer 1 receives: input
- Layer 2 receives: input + layer1
- Layer 3 receives: input + layer1 + layer2
- ...
Standard Network:            DenseNet:

x₀ → H₁ → H₂ → H₃ → H₄       x₀ ─┬───────────────────────┐
                                  ↓                       │
                                  H₁ ──┬────────────┐     │
                                  ↓    │            │     │
                                  H₂ ──┼──────┐     │     │
                                  ↓    │      │     │     │
                                  H₃ ──┤      │     │     │
                                  ↓    ↓      ↓     ↓     ↓
                                  H₄ ← concat(x₀, H₁, H₂, H₃)

Each layer only gets the     Each layer gets ALL previous
previous layer's output      feature maps (concatenated)
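Before the implementation, a quick back-of-the-envelope on channel growth: with growth rate k, every layer appends k feature maps, so a block that starts with c₀ channels ends with c₀ + L·k. The small helper below is ours, using the numbers of DenseNet-BC-100's first block (c₀ = 24, L = 16, k = 12).

# Channel bookkeeping inside a single dense block: layer i receives
# c0 + i*k channels, and the block emits c0 + L*k channels in total.
def dense_block_channels(c0, num_layers, growth_rate):
    per_layer_inputs = [c0 + i * growth_rate for i in range(num_layers)]
    block_output = c0 + num_layers * growth_rate
    return per_layer_inputs, block_output

inputs, out_channels = dense_block_channels(c0=24, num_layers=16, growth_rate=12)
print(inputs[:3], '...', out_channels)   # [24, 36, 48] ... 216
print(16 * (16 + 1) // 2)                # 136 connections = L(L+1)/2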
models/densenet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class DenseLayer(nn.Module):
"""
Single layer in a dense block.
Structure: BN -> ReLU -> Conv1x1 -> BN -> ReLU -> Conv3x3
The 1x1 conv is the "bottleneck" that reduces channels before
the expensive 3x3 conv.
Args:
in_channels: Number of input channels (cumulative from all previous layers)
growth_rate: Number of output channels (k in paper)
bn_size: Bottleneck size factor (output of 1x1 = bn_size * growth_rate)
"""
def __init__(self, in_channels, growth_rate, bn_size=4):
super(DenseLayer, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(
in_channels, bn_size * growth_rate,
kernel_size=1, bias=False
)
self.bn2 = nn.BatchNorm2d(bn_size * growth_rate)
self.conv2 = nn.Conv2d(
bn_size * growth_rate, growth_rate,
kernel_size=3, padding=1, bias=False
)
def forward(self, x):
if isinstance(x, list):
x = torch.cat(x, dim=1)
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
return out
class DenseBlock(nn.Module):
"""
Dense block: multiple densely connected layers.
Each layer receives concatenated outputs from all previous layers.
Total output channels = input_channels + num_layers * growth_rate
"""
def __init__(self, num_layers, in_channels, growth_rate, bn_size=4):
super(DenseBlock, self).__init__()
self.layers = nn.ModuleList()
for i in range(num_layers):
layer = DenseLayer(
in_channels + i * growth_rate,
growth_rate,
bn_size
)
self.layers.append(layer)
def forward(self, x):
features = [x]
for layer in self.layers:
new_features = layer(features)
features.append(new_features)
return torch.cat(features, dim=1)
class Transition(nn.Module):
"""
Transition layer between dense blocks.
Reduces feature map size (spatial) and number of channels.
Structure: BN -> ReLU -> Conv1x1 -> AvgPool2x2
Args:
in_channels: Input channels
out_channels: Output channels (typically in_channels * compression)
"""
def __init__(self, in_channels, out_channels):
super(Transition, self).__init__()
self.bn = nn.BatchNorm2d(in_channels)
self.conv = nn.Conv2d(
in_channels, out_channels,
kernel_size=1, bias=False
)
self.pool = nn.AvgPool2d(2)
def forward(self, x):
out = self.conv(F.relu(self.bn(x)))
out = self.pool(out)
return out
class DenseNet(nn.Module):
"""
DenseNet for CIFAR.
Paper: "Densely Connected Convolutional Networks" (Huang et al., 2017)
Args:
growth_rate (k): Number of channels each layer adds
block_config: Number of layers in each dense block
num_init_features: Channels after initial conv
bn_size: Bottleneck factor in dense layers
compression: Channel reduction factor in transitions (θ in paper)
num_classes: Number of output classes
Example configurations:
DenseNet-BC-100 (k=12): growth_rate=12, block_config=(16, 16, 16)
DenseNet-BC-250 (k=24): growth_rate=24, block_config=(41, 41, 41)
DenseNet-BC-190 (k=40): growth_rate=40, block_config=(31, 31, 31)
"""
def __init__(
self,
growth_rate=12,
block_config=(16, 16, 16),
num_init_features=24,
bn_size=4,
compression=0.5,
num_classes=10
):
super(DenseNet, self).__init__()
self.features = nn.Sequential()
self.features.add_module(
'conv0',
nn.Conv2d(
3, num_init_features,
kernel_size=3, stride=1, padding=1, bias=False
)
)
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = DenseBlock(
num_layers=num_layers,
in_channels=num_features,
growth_rate=growth_rate,
bn_size=bn_size
)
self.features.add_module(f'denseblock{i + 1}', block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
out_features = int(num_features * compression)
trans = Transition(num_features, out_features)
self.features.add_module(f'transition{i + 1}', trans)
num_features = out_features
self.features.add_module('norm_final', nn.BatchNorm2d(num_features))
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(num_features, num_classes)
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
def forward(self, x):
features = self.features(x)
out = F.relu(features)
out = self.avgpool(out)
out = out.view(out.size(0), -1)
out = self.classifier(out)
return out
def densenet_bc_100_k12(num_classes=10):
"""DenseNet-BC-100 with k=12: ~0.8M params"""
return DenseNet(
growth_rate=12,
block_config=(16, 16, 16),
num_init_features=24,
num_classes=num_classes
)
def densenet_bc_250_k24(num_classes=10):
"""DenseNet-BC-250 with k=24: ~15M params"""
return DenseNet(
growth_rate=24,
block_config=(41, 41, 41),
num_init_features=48,
num_classes=num_classes
)
def densenet_bc_190_k40(num_classes=10):
"""DenseNet-BC-190 with k=40: ~25M params"""
return DenseNet(
growth_rate=40,
block_config=(31, 31, 31),
num_init_features=80,
num_classes=num_classes
)
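Tracing the channel counts through DenseNet-BC-100 (k=12) shows how the compression factor of 0.5 keeps the network small. The sketch below simply mirrors the bookkeeping done in DenseNet.__init__ above:

# Trace channel counts through DenseNet-BC-100 (k=12, compression=0.5).
growth_rate, compression = 12, 0.5
num_features = 24                                    # after the initial 3x3 conv
for i, num_layers in enumerate((16, 16, 16)):
    num_features += num_layers * growth_rate         # dense block: +16*12 = +192
    if i < 2:                                        # no transition after the last block
        num_features = int(num_features * compression)
print(num_features)                                  # 342 channels reach the final BN and classifier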
DenseNet-BC-190 (k=40) with our augmentation pipeline:
95.87%
Note: fewer parameters than WideResNet-28-10 (25.6M vs 36.5M), with competitive accuracy
6.1 Training Script
train_advanced.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm
import argparse
from models.resnet import resnet18, resnet110, preact_resnet18
from models.wide_resnet import wrn_28_10, wrn_40_4
from models.pyramidnet import pyramidnet110_a270, pyramidnet200_a240
from models.densenet import densenet_bc_100_k12, densenet_bc_190_k40
from augmentations.cutout import Cutout
from augmentations.mixup import MixupCutMix, mixup_criterion
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)
def get_model(model_name, num_classes):
"""Get model by name."""
models = {
'resnet18': resnet18,
'resnet110': resnet110,
'preact_resnet18': preact_resnet18,
'wrn_28_10': wrn_28_10,
'wrn_40_4': wrn_40_4,
'pyramidnet110': pyramidnet110_a270,
'pyramidnet200': pyramidnet200_a240,
'densenet100': densenet_bc_100_k12,
'densenet190': densenet_bc_190_k40,
}
if model_name not in models:
raise ValueError(f'Unknown model: {model_name}')
return models[model_name](num_classes=num_classes)
def get_transforms(dataset='cifar10'):
"""Get train and test transforms."""
mean = CIFAR10_MEAN if dataset == 'cifar10' else CIFAR100_MEAN
std = CIFAR10_STD if dataset == 'cifar10' else CIFAR100_STD
train_transform = T.Compose([
T.RandomCrop(32, padding=4, padding_mode='reflect'),
T.RandomHorizontalFlip(),
T.AutoAugment(policy=T.AutoAugmentPolicy.CIFAR10),
T.ToTensor(),
T.Normalize(mean, std),
Cutout(n_holes=1, length=16),
])
test_transform = T.Compose([
T.ToTensor(),
T.Normalize(mean, std),
])
return train_transform, test_transform
def get_dataloaders(dataset, batch_size, num_workers=4):
"""Get train and test dataloaders."""
train_transform, test_transform = get_transforms(dataset)
if dataset == 'cifar10':
train_set = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=train_transform
)
test_set = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=test_transform
)
num_classes = 10
else:
train_set = torchvision.datasets.CIFAR100(
root='./data', train=True, download=True, transform=train_transform
)
test_set = torchvision.datasets.CIFAR100(
root='./data', train=False, download=True, transform=test_transform
)
num_classes = 100
train_loader = DataLoader(
train_set, batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=True
)
test_loader = DataLoader(
test_set, batch_size=batch_size, shuffle=False,
num_workers=num_workers, pin_memory=True
)
return train_loader, test_loader, num_classes
def train_epoch(model, loader, criterion, optimizer, device, mixup_fn=None):
"""Train for one epoch."""
model.train()
total_loss = 0
correct = 0
total = 0
pbar = tqdm(loader, desc='Training')
for inputs, targets in pbar:
inputs, targets = inputs.to(device), targets.to(device)
if mixup_fn is not None:
inputs, targets_a, targets_b, lam = mixup_fn(inputs, targets)
optimizer.zero_grad()
outputs = model(inputs)
if mixup_fn is not None:
loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
else:
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
total_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
total += targets.size(0)
if mixup_fn is not None:
correct += (lam * predicted.eq(targets_a).sum().float() +
(1 - lam) * predicted.eq(targets_b).sum().float()).item()
else:
correct += predicted.eq(targets).sum().item()
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
return total_loss / total, 100 * correct / total
def evaluate(model, loader, criterion, device):
"""Evaluate on test set."""
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
total_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return total_loss / total, 100 * correct / total
def count_parameters(model):
"""Count trainable parameters."""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def main():
parser = argparse.ArgumentParser(description='Train advanced architectures')
parser.add_argument('--model', type=str, default='wrn_28_10')
parser.add_argument('--dataset', type=str, default='cifar10')
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--use_mixup', action='store_true')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
train_loader, test_loader, num_classes = get_dataloaders(
args.dataset, args.batch_size
)
model = get_model(args.model, num_classes).to(device)
print(f'Model: {args.model}')
print(f'Parameters: {count_parameters(model):,}')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=0.9,
weight_decay=args.weight_decay,
nesterov=True
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs
)
mixup_fn = None
if args.use_mixup:
mixup_fn = MixupCutMix(mixup_alpha=0.2, cutmix_alpha=1.0)
    os.makedirs('checkpoints', exist_ok=True)  # torch.save below writes into this directory
    best_acc = 0
for epoch in range(args.epochs):
print(f'\nEpoch {epoch + 1}/{args.epochs}')
train_loss, train_acc = train_epoch(
model, train_loader, criterion, optimizer, device, mixup_fn
)
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')
if test_acc > best_acc:
best_acc = test_acc
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_acc': best_acc,
}, f'checkpoints/{args.model}_best.pth')
print(f'New best: {best_acc:.2f}%')
scheduler.step()
print(f'\nFinal Best Accuracy: {best_acc:.2f}%')
if __name__ == '__main__':
main()
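To train, run for example `python train_advanced.py --model wrn_28_10 --use_mixup` (the flags match the argparse definitions above). Afterwards, the best checkpoint can be reloaded for evaluation; the snippet below is a minimal sketch, assuming the default CIFAR-10 setup and the checkpoint path that main() uses for --model wrn_28_10.

# Reload the best checkpoint written by main().
import torch
from models.wide_resnet import wrn_28_10

ckpt = torch.load('checkpoints/wrn_28_10_best.pth', map_location='cpu')
model = wrn_28_10(num_classes=10)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f"Checkpoint from epoch {ckpt['epoch']}, best acc {ckpt['best_acc']:.2f}%")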
6.2 Complete Architecture Comparison
| Architecture            | Parameters | CIFAR-10 Acc | CIFAR-100 Acc |
|-------------------------|------------|--------------|---------------|
| Simple CNN (Part 1)     | 2.85M      | 93.12%       | 71.45%        |
| ResNet-18               | 11.2M      | 94.52%       | 76.23%        |
| ResNet-110              | 1.7M       | 94.87%       | 74.12%        |
| PreAct-ResNet-18        | 11.2M      | 94.76%       | 76.85%        |
| DenseNet-BC-100 (k=12)  | 0.8M       | 94.12%       | 74.56%        |
| DenseNet-BC-190 (k=40)  | 25.6M      | 95.87%       | 79.23%        |
| WideResNet-28-10        | 36.5M      | 96.14%       | 81.02%        |
| PyramidNet-110 (α=270)  | 28.3M      | 96.38%       | 81.67%        |
Which Architecture to Choose?
- Best accuracy: PyramidNet-110 or WideResNet-28-10
- Best efficiency: DenseNet-BC-100 (0.8M params, 94%+ accuracy)
- Best balance: WideResNet-40-4 (8.9M params, ~95.5%)
- Fastest training: WideResNet variants (better parallelization)
Updated Progress
Current: 96.38% | Target: 99.5%
Gap remaining: 3.12%
By switching from our simple CNN to PyramidNet-110, we've improved from 93.12% to 96.38% - a gain of over 3%! But we're still 3.12% away from our 99.5% target.
Part 3 Key Takeaways
- Skip connections are essential - Enable training of 100+ layer networks
- Width vs depth tradeoff - WideResNet shows wider can be better than deeper
- Gradual growth helps - PyramidNet avoids representational bottlenecks
- Dense connections maximize reuse - DenseNet achieves strong results with fewer parameters
- Architecture + augmentation = power - Both are essential for top performance
Next: Part 4 - Training Tricks
We have a powerful architecture at 96.38%. In Part 4, we'll squeeze out more performance with advanced training techniques:
- Label Smoothing - Prevent overconfident predictions
- Knowledge Distillation - Learn from larger models
- Stochastic Depth - Regularization through layer dropout
- Learning Rate Schedules - Cosine annealing, warmup, restarts
- Weight Averaging - SWA and EMA for better generalization