Can't figure out how to apply batch learning

I’m creating a CNN to classify images from the MNIST dataset. I successfully created a class that extends torch.nn.Module and a training function, however, i have to create a dataloader with batch_size = 1, because when increasing the size I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_28/1235996931.py in <module>
      8 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3, weight_decay = 1e-5)
      9 
---> 10 train(model, train_dataloader, criterion, optimizer, device, epochs = 100)

/tmp/ipykernel_28/3373300348.py in train(model, dataloader, criterion, optimizer, device, epochs)
     29 #             expected_output = one_hot(expected_output, device)
     30             img = img.to(device)
---> 31             output = model(img)
     32             l = criterion(output, expected_output)
     33             loss += l

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_28/3291318451.py in forward(self, x)
     11 
     12     def forward(self, x):
---> 13         x = self.conv1(x)
     14         x = self.conv2(x)
     15         x = self.pool(x)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
    461 
    462     def forward(self, input: Tensor) -> Tensor:
--> 463         return self._conv_forward(input, self.weight, self.bias)
    464 
    465 class Conv3d(_ConvNd):

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    458                             _pair(0), self.dilation, self.groups)
    459         return F.conv2d(input, weight, bias, self.stride,
--> 460                         self.padding, self.dilation, self.groups)
    461 
    462     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [5, 1, 3, 3], expected input[1, 128, 28, 28] to have 1 channels, but got 128 channels instead

Here is the code. Please note that I have commented the expected_output = one_hot(expected_output, device) line in the train function in order to get to see what happens with the model. The one_hot function is rudimental but it works, with batch_size = 1. Making it work with bigger batches would need a for loop or a more elegant solution.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

train_dataset = MNIST('./data/train', download = True, train = True, transform = transforms.Compose([transforms.ToTensor()]))
test_dataset = MNIST('./data/test', download = True, train = False, transform = transforms.Compose([transforms.ToTensor()]))

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # input size: [1, 28, 28]
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 5, kernel_size = 3, padding = 1) # output size: [5, 28, 28]
        self.conv2 = nn.Conv2d(in_channels = 5, out_channels = 10, kernel_size = 3, padding = 1) # output size: [10, 28, 28]
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 1) # output size: [10, 27, 27]
        self.relu1 = nn.ReLU(True)
        self.fc = nn.Linear(10 * 27 * 27, 10)
        self.relu2 = nn.ReLU(True)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.relu1(x)
        x = self.fc(x.view(-1))
        x = self.relu2(x)
        return x
    
    def out_to_label(self, out):
        max_out, index = out[0], 0
        for i in range(1, len(out)):
            if out[i] > max_out:
                max_out, index = out[i], i
        return index

def one_hot(x, device):
    if x == 0:
        return torch.tensor([1,0,0,0,0,0,0,0,0,0], dtype=torch.float32, device = device)
    elif x == 1:
        return torch.tensor([0,1,0,0,0,0,0,0,0,0], dtype=torch.float32, device = device)
    elif x == 2:
        return torch.tensor([0,0,1,0,0,0,0,0,0,0], dtype=torch.float32, device = device)
    elif x == 3:
        return torch.tensor([0,0,0,1,0,0,0,0,0,0], dtype=torch.float32, device = device)
    elif x == 4:
        return torch.tensor([0,0,0,0,1,0,0,0,0,0], dtype=torch.float32, device = device)
    elif x == 5:
        return torch.tensor([0,0,0,0,0,1,0,0,0,0], dtype=torch.float32, device = device)
    elif x == 6:
        return torch.tensor([0,0,0,0,0,0,1,0,0,0], dtype=torch.float32, device = device)
    elif x == 7:
        return torch.tensor([0,0,0,0,0,0,0,1,0,0], dtype=torch.float32, device = device)
    elif x == 8:
        return torch.tensor([0,0,0,0,0,0,0,0,1,0], dtype=torch.float32, device = device)
    else:
        return torch.tensor([0,0,0,0,0,0,0,0,0,1], dtype=torch.float32, device = device)


def train(model, dataloader, criterion, optimizer, device, epochs = 1):
    for epoch in range(0, epochs):
        loss = 0
        for data in dataloader:
            img, expected_output = data
            img = img.view(img.size(0), 28, 28)
#             expected_output = one_hot(expected_output, device)
            img = img.to(device)
            output = model(img)
            l = criterion(output, expected_output)
            loss += l
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        print(f'Epoch: [{epoch + 1}/{epochs}] Loss: {loss}')


device = ''
if torch.cuda.is_available(): device = torch.device('cuda')
else: device = torch.device('cpu')

model = MNISTClassifier().to(device)
train_dataloader = DataLoader(train_dataset, batch_size = 128, shuffle = True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3, weight_decay = 1e-5)

train(model, train_dataloader, criterion, optimizer, device, epochs = 100)

The error is raised since your view operation is wrong and creates a 3D tensor while you want to create a 4D input in the shape [batch_size, channels, height, width].
Use img = img.view(img.size(0), 1, 28, 28) or just remove the view operation as the img tensor is already in the right shape.
Also, remove the one_hot method as your target already contains class indices which are expected in a multi-class classification use case using nn.CrossEntropyLoss.
If you really want to create one-hot encoded targets use F.one_hot instead.

Once this is fixed you would have to also fix the view operation in your model to x = self.fc(x.view(x.size(0), -1)) in order to keep the batch size equal.

1 Like

This fixed the issue, thank you so much!