RuntimeError with Channel Mismatch in PyTorch Classifier

Training a classifier using PyTorch results in a channel mismatch error during the forward pass of the model.
The RuntimeError is unusual because the input data shape correctly matches the expected [batch_size, 3, 32, 32] format for CIFAR-10 images, yet the model weirdly anticipates 64 input channels at some point.

Error Message:

RuntimeError: Given groups=1, weight of size [64, 64, 3, 3], expected input[64, 3, 32, 32] to have 64 channels, but got 3 channels instead.

Data Preparation and Loading:

import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

train_data = []
train_labels = []
for i in range(1, 6):
    batch = unpickle(f'data_batch_{i}')
    train_data.append(batch[b'data'])
    train_labels += batch[b'labels']

train_data = np.vstack(train_data).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
train_labels = np.array(train_labels)


test_batch = unpickle('test_batch')
test_data = test_batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
test_labels = np.array(test_batch[b'labels'])


train_data_tensor = torch.tensor(train_data.transpose((0, 3, 1, 2)), dtype=torch.float) / 255.0
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
test_data_tensor = torch.tensor(test_data.transpose((0, 3, 1, 2)), dtype=torch.float) / 255.0
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)


train_loader = DataLoader(TensorDataset(train_data_tensor, train_labels_tensor), batch_size=64, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data_tensor, test_labels_tensor), batch_size=64, shuffle=False)

The shape of the inputs in the first batch after the DataLoader is: torch.Size([64, 3, 32, 32]).​

Model Definition:

import torch.nn as nn
import torch.nn.functional as F

class ImprovedIntermediateBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_convs=3):
        super(ImprovedIntermediateBlock, self).__init__()
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Dropout2d(0.2)
            ) for i in range(num_convs)
        ])
        self.fc = nn.Linear(in_channels, num_convs)

    def forward(self, x):
        channel_means = x.mean([-2, -1])
        weights = F.softmax(self.fc(channel_means), dim=1)
        conv_outputs = torch.stack([conv(x) for conv in self.convs], dim=1)
        output = torch.sum(weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * conv_outputs, dim=1)
        return output

class EnhancedOutputBlock(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(EnhancedOutputBlock, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        self.fc2 = nn.Linear(in_channels // 2, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class CIFAR10Classifier(nn.Module):
    def __init__(self, num_blocks=2, num_classes=10):
        super(CIFAR10Classifier, self).__init__()
        self.blocks = nn.Sequential()
        in_channels = 3
        out_channels = 64
        for i in range(num_blocks):
            self.blocks.add_module(f"block_{i+1}", ImprovedIntermediateBlock(in_channels, out_channels))
            in_channels = out_channels
        self.output_block = EnhancedOutputBlock(out_channels, num_classes)

    def forward(self, x):
        x = self.blocks(x)
        logits = self.output_block(x)
        return logits

model = CIFAR10Classifier().to(device)

Model output:

CIFAR10Classifier(
  (blocks): Sequential(
    (block_1): ImprovedIntermediateBlock(
      (convs): ModuleList(
        (0): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout2d(p=0.2, inplace=False)
        )
        (1-2): 2 x Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout2d(p=0.2, inplace=False)
        )
      )
      (fc): Linear(in_features=3, out_features=3, bias=True)
    )
    (block_2): ImprovedIntermediateBlock(
      (convs): ModuleList(
        (0-2): 3 x Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Dropout2d(p=0.2, inplace=False)
        )
      )
      (fc): Linear(in_features=64, out_features=3, bias=True)
    )
  )
  (output_block): EnhancedOutputBlock(
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc1): Linear(in_features=64, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=10, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

Training Loop where error comes from:

model = CIFAR10Classifier().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)  
criterion = nn.CrossEntropyLoss()


losses = []
def update_plot(epoch, loss):
    losses.append(loss)
    plt.plot(losses, '-x')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.pause(0.001)

num_epochs = 20

for epoch in range(1, num_epochs + 1):
    start_time = time.time()
    running_loss = 0.0
    total_batches = 0
    
    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total_batches += 1

    avg_loss = running_loss / total_batches
    update_plot(epoch, avg_loss) 
    elapsed_time = time.time() - start_time
    
    if epoch % 10 == 0 or epoch == 1:
        print(f'Epoch {epoch}/{num_epochs} - Loss: {avg_loss:.4f} - Time: {elapsed_time:.2f}s')

plt.show()

The issue is raised in:

conv_outputs = torch.stack([conv(x) for conv in self.convs], dim=1)

since self.convs consists of:

ModuleList(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout2d(p=0.2, inplace=False)
  )
  (1-2): 2 x Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout2d(p=0.2, inplace=False)
  )
)

and is applied to the same x input in the shape [64, 3, 32, 32].
self.convs[0] will work while self.convs[1] and self.convs[2] will fail (the former will raise the error of course first).

Thank you was able to identify and resolve the convolutional layer application issue with this