Does loss.backward() affect other variables?

import torch
import torch.nn as nn
import torch.nn.functional as F

class network(nn.Module):
    def __init__(self):
        super(network, self).__init__()
        self.linear1 = nn.Linear(in_features=40, out_features=50)
        self.bn1 = nn.BatchNorm1d(num_features=50)
        self.linear2 = nn.Linear(in_features=50, out_features=25)
        self.bn2 = nn.BatchNorm1d(num_features=25)
        self.linear3 = nn.Linear(in_features=25, out_features=1)

    def forward(self, input):  # input is a batch of 1D feature vectors, shape (batch_size, 40)
        out_list = []
        y = F.relu(self.bn1(self.linear1(input)))
        out_list.append(y)
        y = F.relu(self.bn2(self.linear2(y)))
        out_list.append(y)
        y = self.linear3(y)
        out_list.append(y)
        return out_list
    
model = network()
x = torch.randn(10, 40)
output_list = model(x)

Let’s say that I have a target for each one of the outputs in output_list, so each output has a loss associated with it. Can I run loss.backward() for each one without updating variables that didn’t contribute to that loss? For example, the loss of the first output in the list shouldn’t change the weights of the last linear layer.

If I sum all the losses and run total_loss.backward(), will the loss of the first linear layer affect the update of the last linear layer? I want total_loss.backward() to leave variables that didn’t contribute to the loss untouched.
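
For concreteness, the per-output losses I have in mind would look roughly like this (the targets and the choice of MSE are just placeholders):

# Placeholder targets, one per entry of output_list (shapes match the outputs)
target1 = torch.randn(10, 50)
target2 = torch.randn(10, 25)
target3 = torch.randn(10, 1)

loss1 = F.mse_loss(output_list[0], target1)
loss2 = F.mse_loss(output_list[1], target2)
loss3 = F.mse_loss(output_list[2], target3)
total_loss = loss1 + loss2 + loss3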

Yes, that’s the case: gradients only flow to parameters that were actually used to compute the loss, as shown by the following code snippet using your example code:

model = network()
x = torch.randn(10, 40)
output_list = model(x)

output_list[0].mean().backward(retain_graph=True)
# Only the parameters that produced the first output have gradients;
# linear2, bn2, and linear3 still show grad == None
for name, param in model.named_parameters():
    print(name, param.grad)

output_list[1].mean().backward(retain_graph=True)
# Now linear2 and bn2 have gradients as well (and the grads of linear1/bn1
# have been accumulated), but linear3's grad is still None
for name, param in model.named_parameters():
    print(name, param.grad)

output_list[2].mean().backward()
# After the last backward, every parameter has a gradient
for name, param in model.named_parameters():
    print(name, param.grad)
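
To also cover the summed-loss case from your second question, here is a minimal sketch (reusing the network class from above; the .mean() losses are just placeholders) showing that the gradient total_loss.backward() produces for linear3 is exactly the gradient of the one loss term it contributed to:

model = network()
x = torch.randn(10, 40)
output_list = model(x)

# Placeholder losses, one per output (mean() stands in for a real loss)
loss1 = output_list[0].mean()
loss2 = output_list[1].mean()
loss3 = output_list[2].mean()

total_loss = loss1 + loss2 + loss3
total_loss.backward(retain_graph=True)           # grads from the summed loss
grad_from_total = model.linear3.weight.grad.clone()

model.zero_grad()                                # clear all gradients
loss3.backward()                                 # grads from the last loss only
grad_from_loss3 = model.linear3.weight.grad

# linear3 only contributed to loss3, so both gradients are identical
print(torch.allclose(grad_from_total, grad_from_loss3))  # True

In other words, summing the losses just sums the gradients of the individual terms, and terms that don't depend on a parameter contribute zero to it.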