# Conditionally backpropagating the loss

Say in the `forward()` pass, I generate the network predictions like this:

```python
def forward(self, input):
    conditions = [1, 2, 3, 4, 5]

    out_dict = {}
    for c in conditions:
        output = self.network((input, c))
        out_dict[c] = output  # key each output by its condition

    return out_dict
```

Now, from the ground truth, we know that we only care about condition 3 and only want to backpropagate the loss for this condition. So I did:

```python
loss = loss_fn(out_dict[3])
loss.backward()
```
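Backpropagating only one entry of the dictionary is what `loss.backward()` already does: autograd walks back only through the graph that produced `loss`, so the outputs for the other conditions contribute nothing. A minimal sketch to check this, assuming a hypothetical toy model with one learnable scale per condition (`scales`, `x` and `target` are illustrative, not part of the original network):

```python
import torch

# Hypothetical toy model: one learnable scale per condition 1..5
scales = torch.nn.Parameter(torch.ones(5))
x = torch.tensor([1.0, 2.0])

out_dict = {}
for c in [1, 2, 3, 4, 5]:
    # Each output depends only on its own scale
    out_dict[c] = scales[c - 1] * x

# Build the loss from condition 3 only, then backpropagate
target = torch.zeros(2)
loss = torch.nn.functional.mse_loss(out_dict[3], target)
loss.backward()

# Only the entry for condition 3 (index 2) receives a non-zero gradient
print(scales.grad)  # tensor([0., 0., 5., 0., 0.])
```

The outputs for conditions 1, 2, 4 and 5 still hold autograd graphs, but since `loss` does not depend on them, backward never visits them.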

To verify this, I added a hook to the output of each condition; the code is shown below. If everything worked, we should see "3 [gradient tensor …]", but it is always "5 [gradient tensor …]". So it seems the conditional backpropagation is not working. What is going on here, and how should I backpropagate only through the tensor I am interested in?

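```python
def forward(self, input):
    conditions = [1, 2, 3, 4, 5]

    out_dict = {}
    for c in conditions:
        output = self.network((input, c))
        # Print the condition and the gradient reaching this output during backward()
        output.register_hook(lambda grad: print(c, grad))
        out_dict[c] = output

    return out_dict
```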

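The print statement is probably misleading if you print `c` directly in the lambda expression, as explained here: a Python closure looks up `c` when the hook runs during `backward()`, not when the lambda is created, so the hook that fires sees the final value of `c`, which is 5. To capture the current value, bind `c` as a default argument of the `lambda`: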

```python
lambda grad, c=c: print(c, grad)
```
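The difference comes from Python's late-binding closures: a lambda created in a loop stores the name `c`, not its value, and looks it up only when it is called. A minimal, framework-free sketch (the lists `late` and `bound` are illustrative names):

```python
late = []
bound = []
for c in [1, 2, 3, 4, 5]:
    # Looks up `c` when called, i.e. after the loop has finished
    late.append(lambda: c)
    # The default argument is evaluated now, freezing the current value of `c`
    bound.append(lambda c=c: c)

print([f() for f in late])   # [5, 5, 5, 5, 5]
print([f() for f in bound])  # [1, 2, 3, 4, 5]
```

`functools.partial(print, c)` is an equivalent alternative: it also stores the value of `c` at creation time and, when the hook is called with `grad`, runs `print(c, grad)`.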

Ah! Thatβs it! Thank you for the solution.
I also created a small sample to play around with this. As you can see, the lambda prints the correct index only if I bind it to a local variable.

A weird thing is that when I pre-allocate a string `message` and pass it into the lambda function, the output is still 5.0. I searched online and could not find out when the string is evaluated, but I hypothesize that it is only evaluated when its `__str__()` method is called. Therefore, it keeps a reference to `cond`, which has been set to 5.0 by the end of the loop (i.e. the same reason as for the lambda function).
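On the hypothesis above: the string is most likely not evaluated lazily, because `str.format` returns a finished `str` immediately. What is late-bound is the name `message` inside the lambda, which is rebound on every iteration and therefore refers to the last string, `"condition 5.0"`, by the time the hooks run. A small pure-Python sketch (the lists `late` and `bound` are illustrative names):

```python
late = []
bound = []
for cond in [1.0, 2.0, 3.0, 4.0, 5.0]:
    # str.format runs right now and returns a complete str
    message = "condition {}".format(cond)
    # Late binding: looks up the name `message` when called
    late.append(lambda: message)
    # Default argument: stores the current string object
    bound.append(lambda message=message: message)

print(late[2]())   # condition 5.0
print(bound[2]())  # condition 3.0
```

So the fix is the same as for `c`: pass `message=message` as a default argument of the lambda.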

```python
import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        in_features = 5  # 4 data values + 1 condition value
        self.networks = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=100, bias=True),
            nn.Linear(in_features=100, out_features=2),
        )
        self.initial_weights()

    def forward(self, input):
        num_conditions = 5

        conditions = torch.arange(start=1, end=num_conditions + 1, dtype=torch.float)
        out_features = {}
        for i, cond in enumerate(conditions):
            # Append the condition value to the input vector
            feature = torch.cat((input, cond.unsqueeze(0)), dim=0)
            feature = self.networks(feature)

            # Bind the current condition as a default argument so the hook prints its own value
            feature.register_hook(lambda grad, c=cond.item(): print(c, grad))
            # message = "condition {}".format(cond)
            out_features[i + 1] = feature
        return out_features

    def initial_weights(self):
        for module in self.networks:
            if isinstance(module, nn.Linear):
                if module.weight is not None:
                    nn.init.xavier_uniform_(module.weight, gain=1.0)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)


torch.manual_seed(42)
network = MyNet()
gt_cond = 3
gt = torch.ones(2)
data = torch.rand(4)
out_features = network(data)

loss = nn.functional.mse_loss(out_features[gt_cond], gt)
loss.backward()  # Expect to see 3.0 in the stdout
```
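Instead of printing, the hooks can record which conditions actually received a gradient, which makes the check easy to assert. A minimal sketch, assuming a smaller stand-in for `MyNet` (a single shared `nn.Linear`) and a hypothetical `record_grad` helper; `functools.partial` binds `cond` at registration time, just like a `c=c` default argument:

```python
import functools

import torch
import torch.nn as nn

torch.manual_seed(42)

# Stand-in for MyNet.networks: one shared layer mapping 4 data values + 1 condition to 2 outputs
net = nn.Linear(5, 2)
data = torch.rand(4)
gt = torch.ones(2)
received = {}  # condition -> gradient seen by its hook


def record_grad(cond, grad):
    # Hypothetical helper: store the gradient flowing into the output for `cond`
    received[cond] = grad.clone()


out_features = {}
for cond in range(1, 6):
    feature = net(torch.cat((data, torch.tensor([float(cond)]))))
    # partial freezes the current `cond`; autograd then calls record_grad(cond, grad)
    feature.register_hook(functools.partial(record_grad, cond))
    out_features[cond] = feature

gt_cond = 3
loss = nn.functional.mse_loss(out_features[gt_cond], gt)
loss.backward()

print(sorted(received))  # [3]
```

Only the hook on the condition-3 output fires, because it is the only output in the graph of `loss`.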