CNN with multiple outputs and batch processing

How exactly are batches processed in one iteration? For example, I have built a network that accepts an image and outputs 8 sets of values, each a probability distribution over 36 values. So the total number of outputs for a single image is 8 x 36 = 288. My model returns 8 tensors (x1, x2, …, x8) at the end of the forward function, each of size [1, 36]. Now that I am working with batches, I am confused about how each image is actually processed in each iteration.

Suppose the batch size is set to 10, so the network is given a batch of 10 images at once. Will the network process all 10 images at once, or does it take 1 image at a time, produce the result, then take the next one, and so on until all ten images are processed, and only then calculate the values and return them? If it takes all 10 images and processes them at once, then the network is supposed to give me (batch size x number of sets x number of probability values in each set) = 10 x 8 x 36 = 2880 values, but since the network is designed to give 8 tensors of size [1, 36], how am I losing the other values in the process? For reference, I am sharing the network architecture as well as the training loop. Also, I am attaching the architecture that I am trying to mimic.

Network Architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv1a = nn.Conv2d(16, 16, 3)
        self.batch1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(16, 32, 3)
        self.conv2a = nn.Conv2d(32, 32, 3)
        self.batch2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv3a = nn.Conv2d(64, 64, 3)
        self.batch3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(84672, 128)
        self.fc2 = nn.Linear(128, 36)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = F.relu(self.batch1(self.conv1(x)))
        x = F.relu(self.batch1(self.conv1a(x)))
        x = self.pool1(x)

        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.batch2(self.conv2a(x)))
        x = self.pool2(x)

        x = F.relu(self.batch3(self.conv3(x)))
        x = F.relu(self.batch3(self.conv3a(x)))
        x = self.pool3(x)

        x = x.view(-1,84672)

        x1 = F.relu(self.fc1(x))
        x1 = F.relu(self.fc2(x1))
        x1 = self.softmax(x1)

        x2 = F.relu(self.fc1(x))
        x2 = F.relu(self.fc2(x2))
        x2 = self.softmax(x2)

        x3 = F.relu(self.fc1(x))
        x3 = F.relu(self.fc2(x3))
        x3 = self.softmax(x3)

        x4 = F.relu(self.fc1(x))
        x4 = F.relu(self.fc2(x4))
        x4 = self.softmax(x4)

        x5 = F.relu(self.fc1(x))
        x5 = F.relu(self.fc2(x5))
        x5 = self.softmax(x5)

        x6 = F.relu(self.fc1(x))
        x6 = F.relu(self.fc2(x6))
        x6 = self.softmax(x6)

        x7 = F.relu(self.fc1(x))
        x7 = F.relu(self.fc2(x7))
        x7 = self.softmax(x7)

        x8 = F.relu(self.fc1(x))
        x8 = F.relu(self.fc2(x8))
        x8 = self.softmax(x8)
        return x1,x2,x3,x4,x5,x6,x7,x8

Training iteration:

steps = 0
print_every = 50

epochs = 30

for e in range(epochs):
    running_loss = 0
    for batch_i, data in enumerate(train_loader):

        steps += 1

        # Unpack the batch (same keys as in the test loop below)
        images = data['image'].to(device)
        labels = data['lpno']
        labels = labels.view(labels.size(0), -1)
        labels = labels.to(device)

        #label1,label2,label3,label4,label5,label6,label7,label8 = labels  <--------

        optimizer.zero_grad()
        x1,x2,x3,x4,x5,x6,x7,x8 = net(images)

        x1 = criterion(x1, labels1)
        x2 = criterion(x2, labels2)
        x3 = criterion(x3, labels3)
        x4 = criterion(x4, labels4)
        x5 = criterion(x5, labels5)
        x6 = criterion(x6, labels6)
        x7 = criterion(x7, labels7)
        x8 = criterion(x8, labels8)

        loss = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
        loss.backward()   # Backward pass
        optimizer.step()

        running_loss += loss.item()
        if steps % print_every == 0:
            test_loss = 0
            accuracy = 0

            with torch.no_grad():
                net.eval()
                for data in test_loader:
                    images = data['image'].to(device)
                    labels = data['lpno']
                    labels = labels.view(labels.size(0), -1)
                    labels = labels.to(device)

                    log_ps = net(images)
                    test_loss += criterion(log_ps, labels)
                    ps = torch.exp(log_ps)

                    top_p, top_class = ps.topk(1, dim = 1)
                    equals = top_class == labels.view(*top_class.shape)
                    accuracy += torch.mean(equals.type(torch.FloatTensor))

            net.train()

            trainLoss.append(running_loss/len(train_loader))
            testLoss.append(test_loss/len(test_loader))

            print("Epoch: {}/{}.. ".format(e + 1, epochs),
                  "Test Accuracy: {:.3f}".format(accuracy/len(test_loader)))

For the calculation of the loss, I need to get labels1 to labels8 values but that could be only possible if I somehow know how batches of input get processed.

Architecture that I want to mimic:

All samples in a batch are processed at once, and most layers use dim0 as the batch dimension.
Your input would thus be [batch_size, channels, height, width] and each output would also have the batch dimension in dim0. Note that recurrent layers such as nn.LSTM have a different default shape, but that’s not relevant for your use case.
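
To make the shapes concrete, here is a minimal sketch. It uses a hypothetical toy model (`TinyMultiHead` is an illustration, not the network from this thread) with a shared trunk and 8 heads. With a batch of 10 images, each of the 8 outputs should have shape [10, 36], so no values are lost.

```python
import torch
import torch.nn as nn

class TinyMultiHead(nn.Module):
    """Hypothetical toy model: a shared trunk followed by 8 separate heads."""

    def __init__(self, num_heads=8, num_classes=36):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 4, 3),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),  # keeps dim0 (the batch) and flattens the rest
        )
        self.heads = nn.ModuleList(
            [nn.Linear(4, num_classes) for _ in range(num_heads)]
        )

    def forward(self, x):
        features = self.trunk(x)  # [batch_size, 4]
        return tuple(head(features) for head in self.heads)

model = TinyMultiHead()
images = torch.randn(10, 3, 32, 32)  # [batch_size, channels, height, width]
outputs = model(images)

print(len(outputs), outputs[0].shape)
total_values = sum(out.numel() for out in outputs)  # 10 * 8 * 36
print(total_values)
```

One design note: in the model posted above, x1 to x8 all pass through the same fc1 and fc2, so they would come out identical. Separate heads, as in this sketch, are one way to give each of the 8 positions its own parameters.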

If you are using a DataLoader, both the data and target should be returned as a complete batch.
Let me know if that helps or if you get stuck somewhere.
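
For the per-output labels, here is one possible sketch. It assumes the batched target is an integer tensor of shape [batch_size, 8] holding one class index per output; that layout is an assumption, since the dataset is not shown in this thread.

```python
import torch

batch_size, num_positions, num_classes = 10, 8, 36

# Assumed layout: one class index per position, shape [batch_size, 8].
labels = torch.randint(0, num_classes, (batch_size, num_positions))

# unbind(dim=1) yields 8 tensors of shape [batch_size], one per output.
label1, label2, label3, label4, label5, label6, label7, label8 = labels.unbind(dim=1)

print(label1.shape, label1.dtype)
```

Each of these can then be paired with the matching output, e.g. `criterion(x1, label1)`.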

My forward method looks like this:

As you can see, it needs to return 8 values (x1…x8) at the end of one iteration. But each of these 8 values is just a tensor of size [1, 36]. As you said, it should return batch size * (x1…x8) tensors, right? But that’s not how it’s coming out.

I would recommend replacing

x = x.view(-1, 84672)

with

x = x.view(x.size(0), -1)

to keep the batch size as it is.
Could you change it and check the output shapes again?
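
To see why the hard-coded size matters, the arithmetic can be sketched in plain Python. The helper names are hypothetical; the layer pattern (three blocks of two unpadded 3x3 convolutions followed by a 2x2 max pool, ending in 64 channels) is taken from the model above. `view(-1, n)` infers dim0 from the total element count, so it only equals the batch size when `n` matches the true per-sample feature count.

```python
def flattened_features(height, width, channels=64):
    """Per-sample feature count after three (conv3x3, conv3x3, maxpool2) blocks."""
    for _ in range(3):
        # Two unpadded 3x3 convolutions remove 2 pixels each per dimension.
        height, width = height - 4, width - 4
        # MaxPool2d(2, 2) halves each dimension, rounding down.
        height, width = height // 2, width // 2
    return channels * height * width

def inferred_rows(batch_size, per_sample, n):
    """Size of dim0 that x.view(-1, n) would infer."""
    total = batch_size * per_sample
    if total % n != 0:
        raise ValueError("view(-1, n) would fail: sizes are not compatible")
    return total // n

per_sample = flattened_features(50, 50)
print(per_sample)                        # 256
print(inferred_rows(2, per_sample, 512)) # 1: two samples merged into one row
print(inferred_rows(2, per_sample, 256)) # 2: batch size preserved
```

With `x.view(x.size(0), -1)` the batch size is fixed explicitly and the feature count is inferred instead, so a mismatch shows up as a shape error in the following linear layer rather than as a silently changed batch size.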

Voila! It worked. With minor tweaks to my code, I got past this error. However, as you can see, I need to average the loss before backpropagation, and I now encounter the following error:

        loss = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
        loss = loss.mean()
        print(loss.dtype)
        loss.backward()

Output:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
 in 
     32         loss = loss.mean()
     33         print(loss.dtype)
---> 34         loss.backward()   # Backward pass
     35         optimizer.step()
     36 

~\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
    193                 products. Defaults to ``False``.
    194         """
--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    196 
    197     def register_hook(self, hook):

~\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
---> 99         allow_unreachable=True)  # allow_unreachable flag
    100 
    101 

RuntimeError: expected dtype Float but got dtype Long

As you can see, the dtype of loss is float, but I am still getting this error.

I cannot reproduce this issue using your code:


import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv1a = nn.Conv2d(16, 16, 3)
        self.batch1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2) 
    
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.conv2a = nn.Conv2d(32, 32, 3)
        self.batch2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2, 2)
    
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv3a = nn.Conv2d(64, 64, 3)
        self.batch3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, 2)
    
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 36)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = F.relu(self.batch1(self.conv1(x)))
        x = F.relu(self.batch1(self.conv1a(x)))
        x = self.pool1(x)   
    
        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.batch2(self.conv2a(x)))
        x = self.pool2(x)
    
        x = F.relu(self.batch3(self.conv3(x)))
        x = F.relu(self.batch3(self.conv3a(x)))
        x = self.pool3(x)
    
        x = x.view(-1,512)
    
        x1 = F.relu(self.fc1(x))
        x1 = F.relu(self.fc2(x1)) 
        x1 = self.softmax(x1)
    
        x2 = F.relu(self.fc1(x))
        x2 = F.relu(self.fc2(x2)) 
        x2 = self.softmax(x2)
    
        x3 = F.relu(self.fc1(x))
        x3 = F.relu(self.fc2(x3)) 
        x3 = self.softmax(x3)
    
        x4 = F.relu(self.fc1(x))
        x4 = F.relu(self.fc2(x4)) 
        x4 = self.softmax(x4)
    
        x5 = F.relu(self.fc1(x))
        x5 = F.relu(self.fc2(x5)) 
        x5 = self.softmax(x5)
    
        x6 = F.relu(self.fc1(x))
        x6 = F.relu(self.fc2(x6)) 
        x6 = self.softmax(x6)
    
        x7 = F.relu(self.fc1(x))
        x7 = F.relu(self.fc2(x7)) 
        x7 = self.softmax(x7)
    
        x8 = F.relu(self.fc1(x))
        x8 = F.relu(self.fc2(x8)) 
        x8 = self.softmax(x8)
        return x1,x2,x3,x4,x5,x6,x7,x8
    
model = Net()

x = torch.randn(2, 3, 50, 50)
out1, out2, out3, out4, out5, out6, out7, out8 = model(x)
loss = out1 + out2 + out3 + out4 + out5 + out6 + out7 + out8
loss = loss.mean()
loss.backward()

Note that I reduced the number of features a bit to debug it quickly.
Could you compare my code to yours and try to narrow down the difference?

Actually, the implementation is:

        x1,x2,x3,x4,x5,x6,x7,x8 = net(images)

        x1 = criterion(x1, label1)
        x2 = criterion(x2, label2)
        x3 = criterion(x3, label3)

        x5 = criterion(x5, label5)
        x6 = criterion(x6, label6)
        x7 = criterion(x7, label7)
        x8 = criterion(x8, label8)
       
        loss = x1 + x2 + x3 + x5 + x6 + x7 + x8
        loss = loss.mean()
        loss.backward()
        optimizer.step()

The code still works:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv1a = nn.Conv2d(16, 16, 3)
        self.batch1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2) 
    
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.conv2a = nn.Conv2d(32, 32, 3)
        self.batch2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2, 2)
    
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv3a = nn.Conv2d(64, 64, 3)
        self.batch3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, 2)
    
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 36)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x):
        x = F.relu(self.batch1(self.conv1(x)))
        x = F.relu(self.batch1(self.conv1a(x)))
        x = self.pool1(x)   
    
        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.batch2(self.conv2a(x)))
        x = self.pool2(x)
    
        x = F.relu(self.batch3(self.conv3(x)))
        x = F.relu(self.batch3(self.conv3a(x)))
        x = self.pool3(x)
    
        x = x.view(x.size(0),-1)
    
        x1 = F.relu(self.fc1(x))
        x1 = F.relu(self.fc2(x1)) 
        x1 = self.softmax(x1)
    
        x2 = F.relu(self.fc1(x))
        x2 = F.relu(self.fc2(x2)) 
        x2 = self.softmax(x2)
    
        x3 = F.relu(self.fc1(x))
        x3 = F.relu(self.fc2(x3)) 
        x3 = self.softmax(x3)
    
        x4 = F.relu(self.fc1(x))
        x4 = F.relu(self.fc2(x4)) 
        x4 = self.softmax(x4)
    
        x5 = F.relu(self.fc1(x))
        x5 = F.relu(self.fc2(x5)) 
        x5 = self.softmax(x5)
    
        x6 = F.relu(self.fc1(x))
        x6 = F.relu(self.fc2(x6)) 
        x6 = self.softmax(x6)
    
        x7 = F.relu(self.fc1(x))
        x7 = F.relu(self.fc2(x7)) 
        x7 = self.softmax(x7)
    
        x8 = F.relu(self.fc1(x))
        x8 = F.relu(self.fc2(x8)) 
        x8 = self.softmax(x8)
        return x1,x2,x3,x4,x5,x6,x7,x8
    
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(2, 3, 50, 50)
x1, x2, x3, x4, x5, x6, x7, x8 = model(x)

criterion = nn.NLLLoss()
label1 = torch.randint(0, 36, (2,))
label2 = torch.randint(0, 36, (2,))
label3 = torch.randint(0, 36, (2,))
label4 = torch.randint(0, 36, (2,))
label5 = torch.randint(0, 36, (2,))
label6 = torch.randint(0, 36, (2,))
label7 = torch.randint(0, 36, (2,))
label8 = torch.randint(0, 36, (2,))

x1 = criterion(x1, label1)
x2 = criterion(x2, label2)
x3 = criterion(x3, label3)

x5 = criterion(x5, label5)
x6 = criterion(x6, label6)
x7 = criterion(x7, label7)
x8 = criterion(x8, label8)
   
loss = x1 + x2 + x3 + x5 + x6 + x7 + x8
loss = loss.mean()
loss.backward()
optimizer.step()

Note that I changed nn.Softmax to nn.LogSoftmax, since nn.NLLLoss expects log probabilities.
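
As a small check of that pairing, the sketch below (with random stand-in data, not the model from this thread) compares nn.LogSoftmax followed by nn.NLLLoss against nn.CrossEntropyLoss applied to the raw scores. The two should agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 36)          # raw scores, [batch_size, num_classes]
target = torch.randint(0, 36, (4,))  # class indices, dtype torch.int64

nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
ce = nn.CrossEntropyLoss()(logits, target)

print(torch.allclose(nll, ce))
print(nll.shape)  # scalar: the default reduction already averages over the batch
```

Because the default reduction is 'mean', each criterion call already returns a scalar, so a later `loss.mean()` on the summed losses does not change the value.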

This seemed to work only with L1Loss as the criterion. I am wondering why this is the case.

Also, for printing the test accuracy after the if condition, what changes should I make so that it is compatible with the multiple outputs that my network produces? The inner test loop seems to break at test_loss += criterion(log_ps, labels).

This shouldn’t be the case, as I’m using nn.NLLLoss and it works on my machine.

This is expected, as your model returns 8 outputs. I don’t know what these outputs represent, but I would assume your loss calculation should be the same as during training.
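
A possible sketch of such an evaluation step is below. The helper `evaluate_batch` is hypothetical, and it assumes the outputs are log probabilities of shape [batch_size, 36] and the labels are class indices of shape [batch_size, 8]. Random stand-in tensors replace the real model and test_loader.

```python
import torch
import torch.nn as nn

def evaluate_batch(outputs, labels, criterion):
    """outputs: tuple of [batch, classes] log probabilities;
    labels: [batch, len(outputs)] class indices."""
    # Same loss as in training: one criterion call per output, then summed.
    loss = sum(criterion(out, labels[:, i]) for i, out in enumerate(outputs))
    # Predicted class per position, stacked to [batch, positions].
    preds = torch.stack([out.argmax(dim=1) for out in outputs], dim=1)
    accuracy = (preds == labels).float().mean()
    return loss, accuracy

# Stand-in data instead of a real model and test_loader.
torch.manual_seed(0)
outputs = tuple(torch.log_softmax(torch.randn(5, 36), dim=1) for _ in range(8))
labels = torch.randint(0, 36, (5, 8))

with torch.no_grad():
    loss, accuracy = evaluate_batch(outputs, labels, nn.NLLLoss())

print(loss.item(), accuracy.item())
```

The accuracy here is averaged over all positions; counting a sample as correct only when all 8 positions match would be a stricter alternative.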

That’s the strangest bit. It’s not working with my code. Another strange thing is that I am explicitly converting the tensors to the expected dtype (and printing the dtype for confirmation), and yet the criterion is throwing the same error.

Try to compare my code snippet to yours and check each data type of the tensors.
If you don’t find the issue, please post an executable code snippet to reproduce this issue.
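
One way to do that check is sketched below with a hypothetical `describe` helper and random stand-in tensors. It prints the properties nn.NLLLoss cares about: a floating point input of shape [batch_size, num_classes] and a torch.int64 target of shape [batch_size].

```python
import torch
import torch.nn as nn

def describe(name, tensor):
    """Print the properties that most often cause dtype or shape errors in a loss."""
    print(f"{name}: dtype={tensor.dtype}, shape={tuple(tensor.shape)}, device={tensor.device}")

log_probs = torch.log_softmax(torch.randn(2, 36, requires_grad=True), dim=1)
target = torch.randint(0, 36, (2,))

describe("log_probs", log_probs)  # floating point, [batch_size, num_classes]
describe("target", target)        # torch.int64, [batch_size]

loss = nn.NLLLoss()(log_probs, target)
describe("loss", loss)
loss.backward()
```

nn.L1Loss, by contrast, compares two tensors of the same shape element by element, so it can accept combinations that nn.NLLLoss rejects; a criterion running without an error does not by itself mean the shapes and dtypes match what was intended.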