# Conditionally backpropagating the loss

Say that in the `forward()` pass I generate the network predictions like the following:

```python
def forward(self, input):
    conditions = [1, 2, 3, 4, 5]

    out_dict = {}
    for c in conditions:
        output = self.network((input, c))
        out_dict[c] = output

    return out_dict
```

Now, from the ground truth we know that we only care about condition 3 and only want to backpropagate the loss for this condition. So I did:

```python
loss = loss_fn(out_dict[3])
loss.backward()
```

To verify this, I added a hook to the output of each condition; the code is shown below. If everything works correctly, we should see “3 [gradient tensor …]”, but it always prints “5 [gradient tensor …]”. So it seems the conditional backpropagation is not working. What is going on here, and how should I run the backprop only with the tensor I am interested in?

```python
def forward(self, input):
    conditions = [1, 2, 3, 4, 5]

    out_dict = {}
    for c in conditions:
        output = self.network((input, c))
        # Hook that should print which condition's output receives a gradient
        output.register_hook(lambda grad: print(c, grad))
        out_dict[c] = output

    return out_dict
```

The print statement might be buggy if you are directly printing `c` in the lambda expression, as explained here, so you might need to bind the current value of `c` as a default argument of the `lambda` via:

```python
lambda grad, c=c: print(c, grad)
```
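
This is the standard Python late-binding behavior for closures and is not PyTorch-specific. A minimal sketch without any network:

```python
# Each lambda closes over the variable c, not its value at definition
# time, so all of them see c == 5 once the loop has finished.
buggy = [lambda: print(c) for c in range(1, 6)]
buggy[2]()  # prints 5

# A default argument is evaluated when the lambda is defined,
# which freezes the current value of c for each hook.
fixed = [lambda c=c: print(c) for c in range(1, 6)]
fixed[2]()  # prints 3
```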

Ah! That's it! Thank you for the solution.
I also created a small sample script to play around with this. As you can see, the lambda prints the correct index only when the value is bound as a local (default) argument of the lambda.

A weird thing is that when I pre-allocated a string `message` and passed it into the lambda function, the output was still 5.0. My first guess was lazy evaluation, i.e. that the string is only built when `__str__()` is called, but `str.format` actually builds the string eagerly. The real reason is the same closure behavior as before: the lambda captures the name `message`, which the loop rebinds on every iteration, so every hook ends up with the string built from the last condition (5.0). The same default-argument trick fixes it.

```python
import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        in_features = 5
        self.networks = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=100, bias=True),
            nn.Linear(in_features=100, out_features=2),
        )
        self.initial_weights()

    def forward(self, input):
        num_conditions = 5

        conditions = torch.arange(start=1, end=num_conditions + 1, dtype=torch.float)
        out_features = {}
        for i, cond in enumerate(conditions):
            feature = torch.cat((input, cond.unsqueeze(0)), dim=0)
            feature = self.networks(feature)

            # Bind cond as a default argument so each hook keeps its own value;
            # `lambda grad: print(cond, grad)` would always print the last condition.
            feature.register_hook(lambda grad, cond=cond: print(cond, grad))
            # message = "condition {}".format(cond)
            out_features[i + 1] = feature
        return out_features

    def initial_weights(self):
        for module in self.networks:
            if isinstance(module, nn.Linear):
                if module.weight is not None:
                    nn.init.xavier_uniform_(module.weight, gain=1.0)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)


torch.manual_seed(42)
network = MyNet()
gt_cond = 3
gt = torch.ones(2)
data = torch.rand(4)
out_features = network(data)

loss = nn.functional.mse_loss(gt, out_features[gt_cond])
loss.backward()  # Expect to see 3.0 (and the gradient) in stdout
```
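
To double-check the `message` behavior in isolation, here is a minimal sketch on a toy graph (not the `MyNet` above): binding the pre-formatted string as a default argument makes each hook print its own condition.

```python
import torch

x = torch.ones(2, requires_grad=True)
outputs = {}
for cond in [1.0, 2.0, 3.0]:
    out = x * cond
    message = "condition {}".format(cond)
    # Without `message=message`, every hook would close over the same
    # name and print "condition 3.0"; the default argument freezes it.
    out.register_hook(lambda grad, message=message: print(message, grad))
    outputs[cond] = out

outputs[2.0].sum().backward()  # prints: condition 2.0 tensor([1., 1.])
```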