Say I am training a network with two data types, A and B. The net always produces two outputs, but if the input is of type B, I only want to backpropagate loss1.
Here is some pseudo-code to illustrate the problem:
    loss1 = Loss1(output1, target1)
    if input_is_of_type_B:
        # constant zero: no graph attached, so nothing flows back through output2
        loss2 = torch.tensor(0.0)
    else:
        loss2 = Loss2(output2, target2)
    loss = (loss1 + loss2) / 2
    loss.backward()
Implemented this way, the code does not throw any errors, but will it do what I want? Is there a better way?
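To make the question concrete for a batch that mixes both types, here is a minimal sketch of one alternative I am considering: compute loss2 per sample and mask out the type-B samples before averaging. The network `TwoHeadNet`, the helper `combined_loss`, the tensor shapes, and the use of MSE for both losses are all assumptions made up for illustration, not my real setup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TwoHeadNet(nn.Module):
    """Hypothetical network: a shared trunk feeding two output heads."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(4, 8)
        self.head1 = nn.Linear(8, 1)
        self.head2 = nn.Linear(8, 1)

    def forward(self, x):
        h = torch.relu(self.trunk(x))
        return self.head1(h), self.head2(h)


def combined_loss(net, x, target1, target2, is_type_b):
    """Average of loss1 (all samples) and loss2 (type-A samples only)."""
    output1, output2 = net(x)
    loss1 = F.mse_loss(output1, target1)

    # Per-sample loss2 with shape (batch,), not yet reduced.
    per_sample_loss2 = F.mse_loss(output2, target2, reduction="none").squeeze(1)

    # 1.0 for type-A samples, 0.0 for type-B samples.
    mask_a = (~is_type_b).float()

    # Mean over type-A samples; the clamp avoids dividing by zero
    # when the batch contains only type-B samples.
    loss2 = (per_sample_loss2 * mask_a).sum() / mask_a.sum().clamp(min=1.0)

    return (loss1 + loss2) / 2


net = TwoHeadNet()
x = torch.randn(6, 4)
target1 = torch.randn(6, 1)
target2 = torch.randn(6, 1)
# True where the sample is of type B (made-up labels).
is_type_b = torch.tensor([False, True, False, True, True, False])

loss = combined_loss(net, x, target1, target2, is_type_b)
loss.backward()
```

With this version the type-B samples should contribute nothing to the gradients of `head2`, while `head1` and the trunk still receive gradients from loss1 for every sample. Whether dividing the sum by 2 is still the right weighting when few samples are of type A is something I am unsure about.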
Possible duplicates:
• How to define a suitable "zero loss"?
• How to attach a cost to a graph