How to reset a module's buffers after a forward pass in train mode?

I want to compute a scalar y that depends on many inputs and then call backward on it, as in this snippet:

import torch

net = ...  # This is an nn.Module in train mode with many parameters.
zs = []
# `list_of_inputs` contains many `x`s
for x in list_of_inputs:
    z = net(x)  # Assume each z is a scalar for simplicity.
    zs.append(z)
zs = torch.stack(zs)
# The log-sum-exp is only an example.
y = torch.logsumexp(zs, 0)
y.backward()

This snippet can easily eat up VRAM: each of the len(list_of_inputs) forward passes keeps its intermediate activations attached to the computation graph until backward is finally called on y.

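My solution is to use the chain rule. Denote the model parameters by p. Then dy/dp = Σ_i (dy/dz_i)(dz_i/dp). Computing dy/dz_i is cheap, since y depends on the z_i only through the log-sum-exp and no network parameters are involved. After that, I can recompute each z_i with gradients enabled and call backward on it, feeding it dy/dz_i as the incoming gradient, so that only one forward graph is alive at a time. I propose the following snippet: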

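import torch

net = ...  # The large model, in train mode.
zs = []
with torch.no_grad():  # Pass 1: do not keep any forward graph.
    for x in list_of_inputs:
        z = net(x)  # NOTE #1
        zs.append(z)
zs = torch.stack(zs).requires_grad_()  # Leaf tensor, so zs.grad gets populated.
y = torch.logsumexp(zs, 0)
y.backward()  # Cheap: only reaches zs, not the parameters of net.
zs_grad = zs.grad.clone()
for x, dz in zip(list_of_inputs, zs_grad):
    z = net(x)  # NOTE #2
    # Accumulate dy/dz_i * dz_i/dp at the model parameters;
    # the graph of this single forward pass is freed right after.
    z.backward(dz)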
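As a quick sanity check of the chain-rule decomposition, the sketch below uses a small, made-up deterministic network (no Dropout or BatchNorm) and compares the gradients from one big graph with those from the two-pass approach. It also uses the fact that for y = logsumexp(z), dy/dz_i = softmax(z)_i, so for this particular y the intermediate backward on y can be replaced by a closed form. The model, its sizes and the inputs are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical deterministic model: no Dropout/BatchNorm, one output value.
net = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1))
list_of_inputs = [torch.randn(3) for _ in range(5)]

# Reference: keep every forward graph, then differentiate y directly.
y_ref = torch.logsumexp(torch.stack([net(x).squeeze() for x in list_of_inputs]), 0)
ref_grads = torch.autograd.grad(y_ref, list(net.parameters()))

# Two passes: for y = logsumexp(z), dy/dz_i is softmax(z)_i.
with torch.no_grad():
    zs = torch.stack([net(x).squeeze() for x in list_of_inputs])
zs_grad = torch.softmax(zs, 0)
for x, dz in zip(list_of_inputs, zs_grad):
    # Accumulates dz * dz_i/dp into p.grad, one graph at a time.
    net(x).squeeze().backward(dz)
```

Since torch.autograd.grad does not write into p.grad, the .grad fields hold only the contributions accumulated by the second loop.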

The snippet should work, unless net contains submodules such as Dropout or BatchNorm. Dropout draws a new random mask on every call, so the z computed at NOTE #2 differs from the z at NOTE #1, which breaks the chain rule. BatchNorm in train mode updates its running_mean and running_var buffers on every call, so with two passes they are updated twice per input. This problem is the reason for this question. I speculate that it might be resolved by somehow “undoing” the changes to the model's buffers after computing dy/dz, hence the title.
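One way to try the “undo” idea is to snapshot, before pass 1, everything that pass 1 mutates: the module buffers (e.g. BatchNorm's running_mean, running_var and num_batches_tracked) and the global RNG state that Dropout draws from, then restore both before pass 2 so it replays pass 1 exactly. This is a minimal sketch, assuming the forward pass of net has no other side effects and that nothing else consumes random numbers between the two passes; torch.utils.checkpoint relies on a similar RNG replay when its preserve_rng_state option is enabled. The names two_pass_backward and ScalarNet are hypothetical.

```python
import torch
import torch.nn as nn


class ScalarNet(nn.Module):
    """Hypothetical model with BatchNorm and Dropout that returns a scalar z."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(3, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 1)
        )

    def forward(self, x):
        return self.body(x).sum()


def two_pass_backward(net, list_of_inputs):
    """Accumulate dy/dp into p.grad for y = logsumexp(z_i), z_i = net(x_i)."""
    # Snapshot the state that pass 1 mutates: RNG state(s) and buffers.
    cpu_rng = torch.get_rng_state()
    cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    saved_buffers = [b.detach().clone() for b in net.buffers()]

    # Pass 1: no graph through net is kept.
    with torch.no_grad():
        zs = torch.stack([net(x) for x in list_of_inputs])  # NOTE #1

    # dy/dz_i is cheap: this graph only goes from zs to y.
    zs.requires_grad_()
    y = torch.logsumexp(zs, 0)
    (zs_grad,) = torch.autograd.grad(y, zs)

    # "Undo" pass 1: restore buffers and RNG so pass 2 replays it exactly.
    with torch.no_grad():
        for b, saved in zip(net.buffers(), saved_buffers):
            b.copy_(saved)
    torch.set_rng_state(cpu_rng)
    if cuda_rng is not None:
        torch.cuda.set_rng_state_all(cuda_rng)

    # Pass 2: one graph at a time; each backward frees its graph.
    for x, dz in zip(list_of_inputs, zs_grad):
        z = net(x)  # NOTE #2
        z.backward(dz)
    return y.detach()
```

With net.train(), calling two_pass_backward(net, list_of_inputs) and then stepping the optimizer should give the same gradients and the same final BatchNorm buffers as the single-graph version started from the same RNG state, while holding only one forward graph at a time. The CUDA generator states are saved as well because Dropout on GPU tensors draws from them rather than from the CPU generator.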

Are there any solutions to this? Alternative approaches that compute dy/dp without exploding VRAM usage are also welcome. Thank you so much!