`.backward()` for non-scalar values is not supported. Is there a workaround other than a loop?

`.backward()` is not supported for non-scalar values.
E.g.:

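>>> import torch
>>> bs = 3
>>> n1 = torch.randn(bs, 3, 224, 224)
>>> net = model()  # assumed: any nn.Module mapping (bs, 3, 224, 224) -> (bs, k)
>>> f = net(n1)
>>> f.backward()
Traceback (most recent call last):
  ...
RuntimeError: grad can be implicitly created only for scalar outputs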
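The error refers to the *implicit* gradient: `Tensor.backward()` also accepts an explicit `gradient` tensor (the vector in a vector-Jacobian product), and with one supplied it runs on non-scalar outputs too. A minimal sketch, assuming a small hypothetical `net` and an 8x8 input in place of `model()` and 224x224 so it runs quickly:

```python
import torch

torch.manual_seed(0)
bs = 3

# Hypothetical stand-in for `model()`: maps (bs, 3, 8, 8) -> (bs, 4).
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 4))
n1 = torch.randn(bs, 3, 8, 8, requires_grad=True)

f = net(n1)  # shape (bs, 4), not a scalar

# Supplying `gradient` makes backward() legal on a non-scalar tensor.
# An all-ones tensor weights every output entry equally.
f.backward(gradient=torch.ones_like(f))
grad_ones = n1.grad.clone()
```

With all-ones weights this computes the same thing as `f.sum().backward()`, because the gradient of a sum with respect to each summand is 1.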

But

>>> f[0, 0].backward(retain_graph=True)

works, so

>>> for i in range(bs): f[i, 0].backward(retain_graph=True)

is the only solution, right?
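The loop works because autograd accumulates into `.grad` across `backward()` calls, so it sums the per-sample parameter gradients. A minimal sketch, assuming a hypothetical `nn.Linear` in place of `model()`, comparing that accumulated result with a single backward through `f[:, 0].sum()` (the same entries the loop visits):

```python
import torch

torch.manual_seed(0)
bs = 3

# Hypothetical small model standing in for `model()`.
net = torch.nn.Linear(5, 2)
n1 = torch.randn(bs, 5)

# Loop version: one backward per sample; .grad accumulates across calls.
f = net(n1)
for i in range(bs):
    f[i, 0].backward(retain_graph=True)
loop_grad = net.weight.grad.clone()

# Single-call version: sum the same entries first, then one backward.
net.zero_grad()
f = net(n1)
f[:, 0].sum().backward()
sum_grad = net.weight.grad.clone()
```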

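f.sum().backward(retain_graph=True)

(or `f[:, 0].sum()` to match the loop over column 0). Since each output depends only on its own input, the gradient of the sum with respect to $x_j$ is just the gradient of $F(x_j)$:

$$\frac{\partial}{\partial x_j} \sum_k F(x_k) = \frac{\partial F(x_j)}{\partial x_j}$$

For the parameters, one backward through the sum gives the same total that the loop accumulates in `.grad`.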
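The identity can be checked numerically. A minimal sketch, assuming a hypothetical per-sample network `net` (no layers that mix samples; BatchNorm in training mode, for example, would break the identity): one backward through the summed outputs, compared with differentiating $F$ separately for each $x_j$.

```python
import torch

torch.manual_seed(0)
bs = 3

# Hypothetical per-sample F: each row of x is processed independently.
net = torch.nn.Sequential(torch.nn.Linear(5, 4), torch.nn.Tanh(), torch.nn.Linear(4, 1))
x = torch.randn(bs, 5, requires_grad=True)

# One backward through sum_k F(x_k); row j of x.grad is d(sum)/dx_j.
net(x).sum().backward()
batched = x.grad.clone()


def grad_of_F(xj):
    # Differentiate F at a single sample, detached from the batch.
    xj = xj.detach().clone().requires_grad_(True)
    (g,) = torch.autograd.grad(net(xj).sum(), xj)
    return g


# Reference: dF(x_j)/dx_j computed sample by sample.
per_sample = torch.stack([grad_of_F(x[j]) for j in range(bs)])
```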


Yeah, that's right! Thanks!