Does list comprehension break the DAG for backprop?

Trying to combine several loss functions stored in a dict called loss, like so:

err = torch.sum(torch.stack([l(input, label) for l in loss.values()]))

I am wondering if this breaks the connection for backprop.

Debugging tells me:

<bound method Tensor.backward of tensor(17353.5000, device='cuda:0', grad_fn=)>

which looks OK, but I am not certain whether using an intermediate list followed by torch.stack still produces a complete computational graph.

Maybe someone knows the answer?
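
For what it's worth, one quick sanity check (a minimal sketch with dummy tensors, not my actual losses) is to look at requires_grad and grad_fn on the stacked result; if the graph were broken, grad_fn would be None:

import torch
import torch.nn as nn

# dummy stand-ins for the real input/label and the loss dict
input = torch.randn(4, 3, requires_grad=True)
label = torch.randn(4, 3)
loss = {"mse": nn.MSELoss(), "l1": nn.L1Loss()}

err = torch.sum(torch.stack([l(input, label) for l in loss.values()]))
print(err.requires_grad)  # True
print(err.grad_fn)        # <SumBackward0 object at ...>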

Your approach seems to work and I get valid gradients:

import torch
import torch.nn as nn

output = torch.randn(1, 1, requires_grad=True)
target = torch.randn(1, 1)

# ten identical criteria keyed by index, mimicking your loss dict
loss = {i: nn.MSELoss() for i in range(10)}
err = torch.sum(torch.stack([l(output, target) for l in loss.values()]))
err.backward()
print(output.grad)
> tensor([[-6.1927]])
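
As a side note (my own sketch, not part of the snippet above): each loss tensor in the comprehension already carries its own grad_fn, so a plain Python sum over them builds the same kind of graph; torch.stack just packs them into one tensor before the reduction:

# equivalent reduction without the intermediate stack
output.grad = None                      # clear the gradient from the run above
err2 = sum(l(output, target) for l in loss.values())
err2.backward()
print(output.grad)                      # same value as before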

Thank you, Patrick. That was my observation too.
However, I was caught by surprise, since there seem to be open issues regarding the comprehension syntax, cf. More robust list comprehension · Issue #48153 · pytorch/pytorch · GitHub

But maybe that only relates to the JIT…

This issue seems to be specific to the JIT, but let me know if you see any issues using the eager approach (e.g. None gradients where you would expect valid gradients).
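
If you want a quick way to spot that situation, something along these lines (with a hypothetical model object) would flag parameters that never received a gradient after backward():

# hypothetical check: run after loss.backward() on your own model
for name, p in model.named_parameters():
    if p.requires_grad and p.grad is None:
        print(f"no gradient reached parameter {name}")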
