Say I am training a network with two data types, A and B. The net always produces two outputs, but if the input is of type B, I only want to backpropagate loss1.
Here is some pseudo-code to illustrate the problem:
```python
loss1 = Loss1(output1, target1)
if input_is_of_type_B:
    # type-B input: loss2 should not contribute to the gradient
    loss2 = torch.tensor(0.0)
else:
    loss2 = Loss2(output2, target2)
loss = (loss1 + loss2) / 2
loss.backward()
```
Implemented this way, the code runs without errors, but does it actually do what I want, and is there a better way?
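For context: `torch.tensor(0.0)` is a constant with no `grad_fn`, so adding it contributes nothing to the gradient, and the parameters that only feed `output2` receive no gradient from that step. If a single batch can mix A and B samples, one common alternative is a per-sample mask on loss2. Below is a minimal sketch of that idea, assuming a hypothetical `TwoHeadNet` (shared trunk, two heads), MSE for both losses, and a per-sample boolean flag `is_type_b`; none of these names come from the question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TwoHeadNet(nn.Module):
    """Hypothetical stand-in for the network in the question:
    a shared trunk feeding two output heads."""

    def __init__(self, in_dim=4, hidden=8):
        super().__init__()
        self.trunk = nn.Linear(in_dim, hidden)
        self.head1 = nn.Linear(hidden, 1)
        self.head2 = nn.Linear(hidden, 1)

    def forward(self, x):
        h = torch.relu(self.trunk(x))
        return self.head1(h), self.head2(h)


def combined_loss(output1, output2, target1, target2, is_type_b):
    """loss1 uses every sample; loss2 uses only type-A samples.

    is_type_b is a bool tensor of shape (N,), one flag per sample.
    """
    loss1 = F.mse_loss(output1, target1)
    per_sample2 = F.mse_loss(output2, target2, reduction="none").squeeze(1)
    keep = (~is_type_b).float()  # 1.0 for type A, 0.0 for type B
    # Average over type-A samples only; clamp avoids 0/0 for an all-B batch.
    loss2 = (per_sample2 * keep).sum() / keep.sum().clamp(min=1.0)
    return loss1 + loss2


# Usage: a batch of 6 samples, the first 3 of type B.
model = TwoHeadNet()
x = torch.randn(6, 4)
target1 = torch.randn(6, 1)
target2 = torch.randn(6, 1)
is_type_b = torch.tensor([True, True, True, False, False, False])

out1, out2 = model(x)
loss = combined_loss(out1, out2, target1, target2, is_type_b)
loss.backward()
```

With an all-B batch the mask makes `head2`'s gradients exactly zero, while the trunk and `head1` still learn from loss1. On the scaling: with `(loss1 + loss2) / 2` and `loss2 = 0`, the gradient from loss1 is half what an unscaled sum would give, so type-B inputs are effectively down-weighted. Whether that matters depends on how you want to weight the two cases.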