Calling `.backward()` without arguments is not supported on non-scalar outputs.

For example:

```
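>>> import torch
>>> bs = 3
>>> n1 = torch.randn(bs, 3, 224, 224)
>>> net = model()  # `model` is my network; its output is not a scalar
>>> f = net(n1)
>>> f.backward()
Traceback (most recent call last):
  ...
RuntimeError: grad can be implicitly created only for scalar outputs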
```
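To reproduce this without the original `model()`, here is a minimal sketch that swaps in a small hypothetical CNN. The architecture is an assumption; the only thing that matters is that the output has shape `(bs, 10)` rather than being a scalar.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for `model()`: a tiny CNN that produces one row of
# 10 outputs per sample. Any network with a non-scalar output behaves the same.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

bs = 3
n1 = torch.randn(bs, 3, 224, 224)
f = net(n1)  # shape (bs, 10), not a scalar

try:
    # No `gradient` argument on a non-scalar tensor -> autograd refuses.
    f.backward()
except RuntimeError as err:
    error_message = str(err)
    print(error_message)
```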

But calling it on a single (scalar) element works:

```
>>> f[0, 0].backward(retain_graph=True)
```

So is looping over the batch like this

```
>>> for i in range(bs):
...     f[i, 0].backward(retain_graph=True)
```

the only solution?
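One way to check whether the loop is needed is to compare the gradients it accumulates with those from a single `backward()` call on the summed outputs, and with passing an explicit `gradient` tensor of ones to the non-scalar output. The sketch below assumes a small hypothetical linear network and an 8×8 input purely to keep it fast; it is not the network from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in network with a (bs, 10) output.
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
bs = 3
n1 = torch.randn(bs, 3, 8, 8)

# 1) The per-sample loop: each call adds into p.grad.
f = net(n1)
for i in range(bs):
    f[i, 0].backward(retain_graph=True)
loop_grads = [p.grad.clone() for p in net.parameters()]

# 2) A single backward call on the sum of the same entries.
net.zero_grad()
f = net(n1)
f[:, 0].sum().backward()
sum_grads = [p.grad.clone() for p in net.parameters()]

# 3) A non-scalar backward with an explicit `gradient` of ones.
net.zero_grad()
f = net(n1)
f[:, 0].backward(gradient=torch.ones(bs))
ones_grads = [p.grad.clone() for p in net.parameters()]

same = all(
    torch.allclose(a, b) and torch.allclose(a, c)
    for a, b, c in zip(loop_grads, sum_grads, ones_grads)
)
print(same)
```

This only shows that the *accumulated* gradient matches; since `.grad` sums contributions, none of these three approaches keeps the per-sample gradients separate.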