Calling .backward() with no arguments is not supported for non-scalar outputs. For example:
>>> import torch
>>> bs = 3
>>> n1 = torch.randn(bs, 3, 224, 224)
>>> net = model()
>>> f = net(n1)
>>> f.backward()
Traceback (most recent call last):
  ...
RuntimeError: grad can be implicitly created only for scalar outputs
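For reference, here is a minimal sketch of what the error is asking for. It assumes a small torch.nn.Linear layer as a hypothetical stand-in for model() (the sizes are arbitrary). For a non-scalar output, backward() needs an explicit gradient tensor with the same shape as the output. Autograd then computes a vector-Jacobian product, and passing all ones is equivalent to calling f.sum().backward().

```python
import torch

# Hypothetical stand-in for model(): maps each sample to 4 outputs,
# so f has shape (bs, 4), i.e. it is not a scalar.
torch.manual_seed(0)
bs = 3
net = torch.nn.Linear(10, 4)
x = torch.randn(bs, 10)
f = net(x)

# A non-scalar backward() needs an explicit `gradient` of the same shape as f.
# With all ones, this computes the same result as f.sum().backward().
f.backward(gradient=torch.ones_like(f))
grad_vjp = net.weight.grad.clone()

# Cross-check against the scalar-reduction form on a fresh graph.
net.zero_grad()
net(x).sum().backward()
grad_sum = net.weight.grad.clone()

print(torch.allclose(grad_vjp, grad_sum))  # True
```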
But this works:
>>> f[0, 0].backward(retain_graph=True)
So is looping over the batch like this the only solution?
>>> for i in range(bs):
...     f[i, 0].backward(retain_graph=True)
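As a sketch, assuming a small hypothetical MLP in place of model(), the loop can be compared against a single backward pass. Because .grad accumulates across backward() calls, the loop leaves the sum of the per-sample gradients in each parameter's .grad. Differentiation is linear, so backpropagating once from f[:, 0].sum(), or passing a gradient mask that selects column 0, should give the same result without needing retain_graph. This only applies when the summed gradient is what you want. If you need each sample's gradient separately, a single summed backward pass does not provide it.

```python
import torch

# Hypothetical stand-in for the user's network; any differentiable model works.
torch.manual_seed(0)
bs = 3
net = torch.nn.Sequential(
    torch.nn.Linear(10, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4)
)
x = torch.randn(bs, 10)

# Option A: the per-sample loop. Each call adds into .grad, so after the
# loop every .grad holds the sum over i of d f[i, 0] / d param.
f = net(x)
for i in range(bs):
    f[i, 0].backward(retain_graph=True)
loop_grads = [p.grad.clone() for p in net.parameters()]

# Option B: one backward pass on the summed column (linearity of the derivative).
net.zero_grad()
f = net(x)
f[:, 0].sum().backward()
sum_grads = [p.grad.clone() for p in net.parameters()]

# Option C: equivalently, pass a gradient mask that selects column 0.
net.zero_grad()
f = net(x)
mask = torch.zeros_like(f)
mask[:, 0] = 1.0
f.backward(gradient=mask)
mask_grads = [p.grad.clone() for p in net.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(loop_grads, sum_grads)))   # True
print(all(torch.allclose(a, c) for a, c in zip(loop_grads, mask_grads)))  # True
```

The comparison uses torch.allclose rather than exact equality because summing in a different order can introduce tiny floating-point differences.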