Training a CIFAR-10 classifier in PyTorch fails with a channel-mismatch error during the model's forward pass.
The RuntimeError is confusing because the input batches match the expected [batch_size, 3, 32, 32] shape for CIFAR-10 images, yet somewhere inside the model a layer expects 64 input channels.
Error Message:
RuntimeError: Given groups=1, weight of size [64, 64, 3, 3], expected input[64, 3, 32, 32] to have 64 channels, but got 3 channels instead.
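To narrow down where the mismatch occurs, a forward pre-hook can print each convolution's input shape just before it runs. This is a throwaway diagnostic sketch (it assumes the model and device defined below), not part of the training code:

import torch

def log_input_shape(module, args):
    # Print the incoming tensor shape for every Conv2d before it executes
    print(f"{module.__class__.__name__}: input shape {tuple(args[0].shape)}")

hooks = [m.register_forward_pre_hook(log_input_shape)
         for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
try:
    model(torch.randn(64, 3, 32, 32).to(device))
finally:
    for h in hooks:
        h.remove()  # always detach the hooks afterwards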
Data Preparation and Loading:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
def unpickle(file):
    # Load one pickled CIFAR-10 batch file into a dict keyed by byte strings
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
train_data = []
train_labels = []
# Load the five CIFAR-10 training batches (10,000 images each)
for i in range(1, 6):
    batch = unpickle(f'data_batch_{i}')
    train_data.append(batch[b'data'])
    train_labels += batch[b'labels']
# Each raw row is 3,072 bytes in channel-major order: reshape to NCHW,
# then transpose to NHWC image layout
train_data = np.vstack(train_data).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
train_labels = np.array(train_labels)
test_batch = unpickle('test_batch')
test_data = test_batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
test_labels = np.array(test_batch[b'labels'])

# Transpose back to NCHW for PyTorch and scale pixels to [0, 1]
train_data_tensor = torch.tensor(train_data.transpose((0, 3, 1, 2)), dtype=torch.float) / 255.0
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
test_data_tensor = torch.tensor(test_data.transpose((0, 3, 1, 2)), dtype=torch.float) / 255.0
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)

train_loader = DataLoader(TensorDataset(train_data_tensor, train_labels_tensor), batch_size=64, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data_tensor, test_labels_tensor), batch_size=64, shuffle=False)
The first batch of inputs coming out of the DataLoader has shape torch.Size([64, 3, 32, 32]).
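That shape was checked directly on the loader (a quick sketch using the variables above):

# Inspect the first batch and the full training tensor
inputs, labels = next(iter(train_loader))
print(inputs.shape)             # torch.Size([64, 3, 32, 32])
print(labels.shape)             # torch.Size([64])
print(train_data_tensor.shape)  # torch.Size([50000, 3, 32, 32])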
Model Definition:
import torch.nn as nn
import torch.nn.functional as F
class ImprovedIntermediateBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_convs=3):
        super(ImprovedIntermediateBlock, self).__init__()
        # num_convs parallel conv branches; the first maps in_channels -> out_channels,
        # the rest map out_channels -> out_channels
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Dropout2d(0.2)
            ) for i in range(num_convs)
        ])
        # Produces one softmax weight per branch from the per-channel means of the input
        self.fc = nn.Linear(in_channels, num_convs)

    def forward(self, x):
        channel_means = x.mean([-2, -1])                    # [B, C]
        weights = F.softmax(self.fc(channel_means), dim=1)  # [B, num_convs]
        # Every branch is applied to the same input x, then combined by the weights
        conv_outputs = torch.stack([conv(x) for conv in self.convs], dim=1)
        output = torch.sum(weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * conv_outputs, dim=1)
        return output
class EnhancedOutputBlock(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(EnhancedOutputBlock, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        self.fc2 = nn.Linear(in_channels // 2, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
class CIFAR10Classifier(nn.Module):
    def __init__(self, num_blocks=2, num_classes=10):
        super(CIFAR10Classifier, self).__init__()
        self.blocks = nn.Sequential()
        in_channels = 3
        out_channels = 64
        for i in range(num_blocks):
            self.blocks.add_module(f"block_{i+1}", ImprovedIntermediateBlock(in_channels, out_channels))
            in_channels = out_channels
        self.output_block = EnhancedOutputBlock(out_channels, num_classes)

    def forward(self, x):
        x = self.blocks(x)
        logits = self.output_block(x)
        return logits
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CIFAR10Classifier().to(device)
Printing the model gives:
CIFAR10Classifier(
  (blocks): Sequential(
    (block_1): ImprovedIntermediateBlock(
      (convs): ModuleList(
        (0): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout2d(p=0.2, inplace=False)
        )
        (1-2): 2 x Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout2d(p=0.2, inplace=False)
        )
      )
      (fc): Linear(in_features=3, out_features=3, bias=True)
    )
    (block_2): ImprovedIntermediateBlock(
      (convs): ModuleList(
        (0-2): 3 x Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout2d(p=0.2, inplace=False)
        )
      )
      (fc): Linear(in_features=64, out_features=3, bias=True)
    )
  )
  (output_block): EnhancedOutputBlock(
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc1): Linear(in_features=64, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)
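The error also reproduces outside the training loop: instantiating the first block on its own and feeding it a random 3-channel tensor fails identically, which points at the block rather than the data pipeline. A minimal sketch:

# Minimal reproduction with a single block and random data (no DataLoader involved)
block = ImprovedIntermediateBlock(in_channels=3, out_channels=64)
x = torch.randn(64, 3, 32, 32)
out = block(x)  # raises the same RuntimeError: expected 64 channels, got 3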
Training loop (where the error is raised):
import time
import torch.optim as optim
import matplotlib.pyplot as plt

model = CIFAR10Classifier().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

losses = []

def update_plot(epoch, loss):
    # Append the epoch loss and redraw the live training curve
    losses.append(loss)
    plt.plot(losses, '-x')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.pause(0.001)

num_epochs = 20
for epoch in range(1, num_epochs + 1):
    start_time = time.time()
    running_loss = 0.0
    total_batches = 0
    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)  # <- the RuntimeError is raised here
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total_batches += 1
    avg_loss = running_loss / total_batches
    update_plot(epoch, avg_loss)
    elapsed_time = time.time() - start_time
    if epoch % 10 == 0 or epoch == 1:
        print(f'Epoch {epoch}/{num_epochs} - Loss: {avg_loss:.4f} - Time: {elapsed_time:.2f}s')
plt.show()