I'm trying to combine several loss functions stored in a dict called loss, like so:
err = torch.sum(torch.stack([l(input, label) for l in loss.values()]))
I am wondering if this breaks the connection for backprop.
Debugging tells me:
<bound method Tensor.backward of tensor(17353.5000, device='cuda:0', grad_fn=)>
which looks OK, but I am not certain whether building an intermediate list and then calling torch.stack keeps the computational graph intact.
Maybe someone knows the answer?
Your approach seems to work and I get valid gradients:
import torch
import torch.nn as nn

output = torch.randn(1, 1, requires_grad=True)
target = torch.randn(1, 1)
loss = {i: nn.MSELoss() for i in range(10)}
err = torch.sum(torch.stack([l(output, target) for l in loss.values()]))
err.backward()
print(output.grad)
> tensor([[-6.1927]])
Thank you, Patrick. That was my observation too.
However, I was caught by surprise, since there seem to be open issues regarding comprehension syntax, cf. More robust list comprehension · Issue #48153 · pytorch/pytorch · GitHub
But maybe that only relates to the JIT…
This issue seems to be specific to the JIT, but let me know if you see any issues using the eager approach (e.g. None gradients where you would expect valid gradients).
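As a sanity check, here is a minimal sketch (using the same toy setup as above, with made-up variable names) that compares the gradient from the stack-and-sum approach against a plain Python sum over the same losses in eager mode. If the list comprehension broke the graph, the two gradients would differ or be missing:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
output = torch.randn(1, 1, requires_grad=True)
target = torch.randn(1, 1)
losses = {i: nn.MSELoss() for i in range(10)}

# Variant 1: build a list via comprehension, stack, then sum
err_stack = torch.sum(torch.stack([l(output, target) for l in losses.values()]))
err_stack.backward()
grad_stack = output.grad.clone()

# Variant 2: plain Python sum over the same loss terms
output.grad = None  # reset accumulated gradient before the second backward
err_sum = sum(l(output, target) for l in losses.values())
err_sum.backward()
grad_sum = output.grad.clone()

# Both variants should yield identical gradients
print(torch.allclose(grad_stack, grad_sum))  # → True
```

Both paths go through regular autograd ops (the Python list itself is not part of the graph; only the tensors it holds are), so the intermediate list is harmless in eager mode.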