Backward prop on a subset of layers in a multi-loss model

Hello everyone,

I am building a network that is trained to classify handwritten digits from the MNIST dataset (a multi-class task) and, at the same time, to compare the two digits it has just recognized, outputting 0 if the first digit is smaller than the second and 1 otherwise.

For these two tasks, I use NLLLoss for the multi-class problem and BCELoss for the binary classification problem. The first output is the multi-class prediction and the second is the comparison. Here is my model; it works with mini-batches of size 25, and the input size is (2, 14, 14), where the two channels hold the pair of digits to be compared:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net_convo(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.convlayer1 = nn.Sequential(
            nn.Conv2d(2, 32, 3),
            nn.BatchNorm2d(32),
            nn.ReLU())
        
        self.convlayer2 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        
        self.dropout = nn.Dropout(0.2)
        self.lin1 = nn.Linear(3200, 10)
        self.lin2 = nn.Linear(10, 1)
        self.lin3 = nn.Linear(2, 1)
        
        self.logsoft = nn.LogSoftmax(dim=2)

    def forward(self, x):

        # Input size is (25, 2, 14, 14)
        x = self.convlayer1(x)
        x = self.convlayer2(x)
        
        x = x.view(25, 2, -1)
        
        # (25, 2, 3200)
        x = F.relu(self.lin1(x))
        # (25, 2, 10)
        output1 = self.logsoft(x)
        
        # (25, 2, 10)
        x = F.relu(self.lin2(x))
        
        # (25, 2, 1)
        x = torch.flatten(x, 1)
        
        # (25, 2)
        output2 = torch.sigmoid(self.lin3(x))
        # (25, 1)
        return output1, output2

The outputs are fed into two losses, with a separate target for each task. I am getting good results on the multi-class problem, but I consistently find that the two performances are tightly linked: a given performance on the first task always comes with a corresponding performance on the second, which suggests that the two losses are interfering with each other through the rest of the model.

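optimizer.zero_grad()  # clear gradients from the previous step, in case this is not already done earlier in the loop
output1, output2 = model(train_input.narrow(0, b, mini_batch_size))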
loss1 = criterion1(output1.view(-1, 10), train_classes.narrow(0, b, mini_batch_size).view(-1))
loss2 = criterion2(output2.view(-1), train_target.narrow(0, b, mini_batch_size).to(torch.float32))
loss = loss1 + loss2
loss.backward()
optimizer.step()

What I need is for the first loss (multi-class) to backpropagate only through the convolutional layers and the first linear layer, while the second loss backpropagates only through the last linear layers after the first output. Is there a way to achieve that? Perhaps using the requires_grad argument…

Thank you so much.

Your forward method looks generally fine. To make sure output2 (and its corresponding loss) only produces gradients for the last two linear layers (lin2 and lin3), you would have to detach the input to these layers:

    def forward(self, x):

        # Input size is (25, 2, 14, 14)
        x = self.convlayer1(x)
        x = self.convlayer2(x)
        
        x = x.view(25, 2, -1)
        
        # (25, 2, 3200)
        x = F.relu(self.lin1(x))
        # (25, 2, 10)
        output1 = self.logsoft(x)
        
        # (25, 2, 10)
        x = F.relu(self.lin2(x.detach())) # add detach here!
        
        # (25, 2, 1)
        x = torch.flatten(x, 1)
        
        # (25, 2)
        output2 = torch.sigmoid(self.lin3(x))
        # (25, 1)
        return output1, output2

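This stops the backward pass originating from output2 at that point, so loss2 won’t calculate any gradients for the previous layers (convlayer1, convlayer2 and lin1). output1 is computed before the detach, so loss1 still backpropagates through those layers as before, and since neither lin2 nor lin3 feeds into output1, loss1 never reaches them. You can keep calling loss.backward() on the sum loss1 + loss2.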
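
If you want to convince yourself that the two losses are really separated, here is a minimal sketch of a sanity check. It uses a toy stand-in rather than your actual model: `trunk` and `head` are hypothetical names playing the roles of (conv layers + lin1) and (lin2 + lin3), and the shapes are made up for illustration. It asks autograd for the gradient of each loss with respect to the "other" part of the network and expects to get nothing back.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions, not the real model):
# `trunk` ~ convlayer1 + convlayer2 + lin1, `head` ~ lin2 + lin3.
trunk = nn.Linear(8, 10)
head = nn.Linear(10, 1)

x = torch.randn(4, 8)
classes = torch.randint(0, 10, (4,))        # multi-class targets
target = torch.randint(0, 2, (4,)).float()  # binary targets

features = trunk(x)
output1 = torch.log_softmax(features, dim=1)
# detach() cuts the graph here, exactly like x.detach() in forward()
output2 = torch.sigmoid(head(features.detach())).view(-1)

loss1 = nn.NLLLoss()(output1, classes)
loss2 = nn.BCELoss()(output2, target)

# With allow_unused=True, autograd returns None for a tensor that the
# loss does not depend on (instead of raising an error).
grad_trunk_from_loss2 = torch.autograd.grad(
    loss2, trunk.weight, retain_graph=True, allow_unused=True)[0]
grad_head_from_loss1 = torch.autograd.grad(
    loss1, head.weight, retain_graph=True, allow_unused=True)[0]
grad_trunk_from_loss1 = torch.autograd.grad(
    loss1, trunk.weight, retain_graph=True)[0]

print("loss2 -> trunk:", grad_trunk_from_loss2)
print("loss1 -> head: ", grad_head_from_loss1)

# The usual single backward call on the summed loss still works.
(loss1 + loss2).backward()

# The trunk gradient after backward() should match the loss1-only gradient.
trunk_matches_loss1 = torch.allclose(trunk.weight.grad, grad_trunk_from_loss1)
print("trunk grad comes from loss1 only:", trunk_matches_loss1)
```

Both cross gradients should come back as None, and the gradient stored on the trunk after the combined backward call should equal the gradient of loss1 alone. The same check can be run on the real model by swapping in its parameters, e.g. model.lin1.weight and model.lin2.weight.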