Conditionally backpropagating the loss

Say that in the forward() pass I generate the network predictions as follows:

def forward(self, input):
    conditions = [1, 2, 3, 4, 5]
    
    out_dict = {}
    for c in conditions:
        output = self.network( (input, c) ) 
        out_dict[c] = output  # keyed by the condition value

    return out_dict

From the ground truth, we know that we only care about condition 3 and only want to backpropagate the loss for this condition. So I did:

loss = loss_fn(out_dict[3])
loss.backward()
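
For reference, here is a minimal sketch of this pattern. It assumes a toy nn.Linear as a hypothetical stand-in for self.network and an MSE loss against a made-up target, neither of which comes from the real model. It uses retain_grad() rather than hooks to see which outputs receive a gradient:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for self.network: 4 input features plus 1 condition value.
network = nn.Linear(5, 2)
x = torch.rand(4)

out_dict = {}
for c in [1, 2, 3, 4, 5]:
    features = torch.cat((x, torch.tensor([float(c)])))
    output = network(features)
    output.retain_grad()  # keep .grad on this non-leaf tensor for inspection
    out_dict[c] = output

# Build the loss from the selected condition only (made-up target of ones).
target = torch.ones(2)
loss = nn.functional.mse_loss(out_dict[3], target)
loss.backward()

# Record which outputs actually received a gradient.
received_grad = {c: out.grad is not None for c, out in out_dict.items()}
print(received_grad)
```

In this sketch only the entry for condition 3 should come back True, since the other outputs are not part of the graph that produced the loss.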

To verify this, I added a hook to the output of each condition; the code is shown below. If everything worked as intended, we should see "3 [gradient tensor …]", but it is always "5 [gradient tensor …]". So it seems the conditional backpropagation is not working. What is going on here, and how should I backpropagate only through the tensor I am interested in?

def forward(self, input):
    conditions = [1, 2, 3, 4, 5]
    
    out_dict = {}
    for c in conditions:
        output = self.network( (input, c) ) 
       
        # Hook meant to print which condition received a gradient.
        output.register_hook(lambda grad: print(c, grad))

        out_dict[c] = output

    return out_dict

The print statement itself might be the bug: if you print c directly inside the lambda expression, c is only looked up when the hook runs, as explained here. You might need to bind the current value of c as a default argument of the lambda function via:

lambda grad, c=c: print(c, grad)
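
To illustrate the difference outside of PyTorch, here is a small sketch in plain Python. The callbacks return the value instead of printing it, and the "g" argument is just a placeholder for the gradient; both are choices made for this illustration only:

```python
conditions = [1, 2, 3, 4, 5]

late_bound = []
early_bound = []
for c in conditions:
    # Looks up `c` only when called, so it sees the value left by the last iteration.
    late_bound.append(lambda grad: (c, grad))
    # The default argument is evaluated now, so each callback keeps its own value.
    early_bound.append(lambda grad, c=c: (c, grad))

late_result = [hook("g")[0] for hook in late_bound]
early_result = [hook("g")[0] for hook in early_bound]
print(late_result)   # [5, 5, 5, 5, 5]
print(early_result)  # [1, 2, 3, 4, 5]
```

So the hook on the output for condition 3 can fire correctly and still print 5, simply because it reads c after the loop has finished.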

Ah! That's it! Thank you for the solution.
I also created a small code sample to play around with this. As you can see, the lambda prints the correct index only if I use a local variable.

A weird thing is that when I try to pre-allocate a string message and use it inside the lambda function, the output is still 5.0. I searched the internet and could not find out when the string is evaluated, but my hypothesis is that it is only evaluated when the __str__() function is called. It would therefore keep a reference to cond, which has been set to 5.0 by the end of the loop (i.e. the same reason as with the lambda function).

import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        in_features = 5
        self.networks = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=100, bias=True),
            nn.Linear(in_features=100, out_features=2),
        )
        self.initial_weights()

    def forward(self, input):
        num_conditions = 5

        conditions = torch.arange(start=1, end=num_conditions + 1, dtype=torch.float)
        out_features = {}
        for i, cond in enumerate(conditions):
            feature = torch.cat((input, cond.unsqueeze(0)), dim=0)
            feature = self.networks(feature)

            # message = "condition {}".format(cond)
            feature.register_hook(lambda grad, c=cond: print(c))
            out_features[i + 1] = feature
        return out_features

    def initial_weights(self):
        for module in self.networks:
            if isinstance(module, nn.Linear):
                if module.weight is not None:
                    nn.init.xavier_uniform_(module.weight, gain=1.0)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)


torch.manual_seed(42)
network = MyNet()
gt_cond = 3
gt = torch.ones(2)
data = torch.rand(4)
out_features = network(data)

# mse_loss expects (input, target): the prediction goes first.
loss = nn.functional.mse_loss(out_features[gt_cond], gt)
loss.backward()  # Expect the hook to print tensor(3.)
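
One way to test the hypothesis about the pre-allocated message is the following plain-Python sketch, which uses floats instead of tensors and returns the string instead of printing it (both assumptions of this illustration). It suggests that str.format builds the string immediately, and that the 5.0 more likely comes from the lambda looking up the name message late, after it has been rebound by the last iteration:

```python
hooks_by_name = []
hooks_by_value = []
messages = []
for cond in [1.0, 2.0, 3.0, 4.0, 5.0]:
    message = "condition {}".format(cond)  # formatted right here, not lazily
    messages.append(message)
    # Refers to the *name* `message`, which is rebound on every iteration.
    hooks_by_name.append(lambda grad: message)
    # Captures the current string object as a default argument.
    hooks_by_value.append(lambda grad, message=message: message)

print(messages[2])              # condition 3.0
print(hooks_by_name[2](None))   # condition 5.0
print(hooks_by_value[2](None))  # condition 3.0
```

If that reading is right, writing the hook as lambda grad, message=message: print(message) should print the expected condition as well.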