I’m interested in computing a `y` that depends on many inputs and calling backward on it, as in this snippet:

```
import torch

net = ...  # This is an nn.Module in train mode with many parameters.
zs = []
# `list_of_inputs` contains many `x`s
for x in list_of_inputs:
    z = net(x)  # Assume each z is a scalar for simplicity.
    zs.append(z)
# torch.stack rather than torch.cat: 0-dim tensors cannot be concatenated.
zs = torch.stack(zs)
# The log-sum-exp is just an example.
y = torch.logsumexp(zs, 0)
y.backward()
```

This snippet can easily exhaust VRAM, because the computation graphs of all `len(list_of_inputs)` forward passes, together with their saved activations, stay alive until backward is finally called on `y`.

My solution is to use the chain rule. Denote the model parameters by p. Then dy/dp = [sum over i] dy/dzi dzi/dp. Computing dy/dzi is cheap, since no network parameters are involved in that step. After that, we can call backward on each zi, feeding it dy/dzi. I came up with the following snippet:

```
net = ...  # The large model, in train mode.
zs = []
for x in list_of_inputs:
    z = net(x)  # NOTE #1
    zs.append(z.detach())
# torch.stack rather than torch.cat: 0-dim tensors cannot be concatenated.
zs = torch.stack(zs).requires_grad_()
y = torch.logsumexp(zs, 0)
y.backward()
zs_grad = zs.grad.clone()
for x, dz in zip(list_of_inputs, zs_grad):
    z = net(x)  # NOTE #2
    # Accumulate gradients into the model parameters; the graph of
    # this forward pass is freed once backward returns.
    z.backward(gradient=dz)
```
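As a sanity check on the chain-rule argument, the two-pass scheme can be compared against a single backward through the full graph on a small deterministic model. This is only a sketch: the tiny `nn.Sequential` network, the input sizes, and the names `ref`, `two_pass`, and `max_diff` are assumptions for illustration, not part of the original setup.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical small, deterministic stand-in for the large model.
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
list_of_inputs = [torch.randn(4) for _ in range(5)]

# Reference: keep every forward graph alive and call backward once.
zs = torch.stack([net(x).squeeze() for x in list_of_inputs])
torch.logsumexp(zs, 0).backward()
ref = [p.grad.clone() for p in net.parameters()]
net.zero_grad()

# Two-pass scheme: first get dy/dz_i without keeping any graph.
with torch.no_grad():
    zs = torch.stack([net(x).squeeze() for x in list_of_inputs])
zs.requires_grad_()
torch.logsumexp(zs, 0).backward()

# Second pass: recompute each z_i and push dy/dz_i through its graph.
for x, dz in zip(list_of_inputs, zs.grad):
    net(x).squeeze().backward(gradient=dz)
two_pass = [p.grad.clone() for p in net.parameters()]

# Largest absolute difference between the two sets of parameter gradients.
max_diff = max((a - b).abs().max().item() for a, b in zip(ref, two_pass))
print(max_diff)
```

With no stochastic or stateful layers, the two sets of gradients should agree up to floating-point rounding.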

The snippet should work well unless `net` contains submodules such as `BatchNorm` or `Dropout`. These cause the `z` computed at `NOTE #1` to differ from the `z` at `NOTE #2`, breaking the chain rule. This problem is exactly why I am asking. I speculate that it might be resolved by somehow “undoing” the model buffer state after computing dy/dz, which is where the title comes from.
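One way the “undoing” idea might look is to snapshot the module buffers and the random number generator state before the first pass and restore both before the second. The sketch below is an assumption-laden illustration, not a confirmed solution: the small model is hypothetical, it runs on CPU only (a CUDA model would need `torch.cuda.get_rng_state` and `torch.cuda.set_rng_state` as well), and it uses a batch of three inputs so that `BatchNorm1d` can run in train mode.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical model containing both a stateful and a stochastic layer.
net = nn.Sequential(
    nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Dropout(0.5), nn.Linear(8, 1)
)
net.train()
x = torch.randn(3, 4)

# Snapshot the buffers (running_mean, running_var, num_batches_tracked)
# and the CPU RNG state that Dropout draws its mask from.
buffers = {name: buf.clone() for name, buf in net.named_buffers()}
rng_state = torch.get_rng_state()

# First pass: no graph is needed, only the value of z.
with torch.no_grad():
    z1 = net(x).sum()

# "Undo" the side effects of the first pass.
with torch.no_grad():
    for name, buf in net.named_buffers():
        buf.copy_(buffers[name])
torch.set_rng_state(rng_state)

# Second pass: same dropout mask, and the running statistics are
# updated only once overall instead of twice.
z2 = net(x).sum()
print(torch.allclose(z1, z2))
```

In the two-loop setting, saving the RNG state once before the first loop and restoring it once before the second loop should reproduce the same masks, provided both loops visit the inputs in the same order and nothing else consumes random numbers in between.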

Are there any solutions to it? Alternative approaches to compute dy/dp without exploding VRAM are also welcome. Thank you so much!