I’m interested in computing a scalar y that depends on many inputs and then calling backward on it, as in this snippet:
import torch

net = ...  # This is an nn.Module in train mode with many parameters.
zs = []
# `list_of_inputs` contains many `x`s.
for x in list_of_inputs:
    z = net(x)  # Assume each z is a scalar for simplicity.
    zs.append(z)
zs = torch.cat(zs)
# The log-sum-exp is just an example of the final reduction.
y = torch.logsumexp(zs, 0)
y.backward()
This snippet can easily eat up VRAM, because the forward graph (with its saved activations) of every one of the len(list_of_inputs) passes through net stays attached to y until backward is finally called on it.
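As a quick illustration (made-up layer sizes, CUDA assumed; this is not my real code), you can watch the allocated memory climb with every retained forward graph:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1)).cuda().train()
list_of_inputs = [torch.randn(64, 1024, device="cuda") for _ in range(16)]

zs = []
for i, x in enumerate(list_of_inputs):
    zs.append(net(x).sum())  # each retained graph keeps its activations alive
    print(f"after input {i}: {torch.cuda.memory_allocated() / 2**20:.1f} MiB")

y = torch.logsumexp(torch.stack(zs), 0)
y.backward()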
My solution is to use the chain rule. Denote the model parameters by p. Then dy/dp = sum_i (dy/dz_i) (dz_i/dp). Computing each dy/dz_i is cheap, since the reduction from the z_i's to y involves no network parameters (for logsumexp, dy/dz_i is simply softmax(zs)_i). After that, we may call backward on each z_i, feeding it dy/dz_i as the incoming gradient. I propose the following snippet:
net = ...  # The large model, in train mode.
zs = []
for x in list_of_inputs:
    z = net(x)  # NOTE #1
    zs.append(z.detach())
zs = torch.cat(zs).requires_grad_()
y = torch.logsumexp(zs, 0)
y.backward()  # Populates zs.grad, i.e. dy/dz.
zs_grad = zs.grad.clone()
for x, dz in zip(list_of_inputs, zs_grad):
    z = net(x)  # NOTE #2
    # Accumulate dy/dp = (dy/dz) * (dz/dp) into the parameters' .grad;
    # the graph of this forward pass is freed right after the call.
    z.backward(gradient=dz)
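On a toy model without BatchNorm or Dropout, this two-pass scheme does reproduce the naive gradients. Here is a quick, self-contained check (layer sizes and the number of inputs are made up):

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 1))
inputs = [torch.randn(1, 8) for _ in range(4)]

# Reference: the naive single-graph version.
y_ref = torch.logsumexp(torch.stack([net(x).squeeze() for x in inputs]), 0)
y_ref.backward()
ref_grads = [p.grad.clone() for p in net.parameters()]
net.zero_grad()

# Two-pass version from the snippet above.
zs = torch.stack([net(x).squeeze().detach() for x in inputs]).requires_grad_()
torch.logsumexp(zs, 0).backward()
for x, dz in zip(inputs, zs.grad):
    net(x).squeeze().backward(gradient=dz)

print(all(torch.allclose(p.grad, g) for p, g in zip(net.parameters(), ref_grads)))  # True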
This snippet should work well, unless net contains submodules such as BatchNorm or Dropout. They cause the z computed at NOTE #1 to differ from the z computed at NOTE #2, breaking the chain rule. That problem is exactly why I’m opening this topic. I speculate it might be resolved by somehow “undoing” the model’s buffer (and random-number-generator) state after computing dy/dz, which is where the title comes from.
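To make that “undoing” idea concrete, here is a rough, untested sketch of what I imagine the snapshot/restore could look like. The helpers snapshot_net_state / restore_net_state are hypothetical, and the sketch assumes all randomness (Dropout masks) comes from the global torch/CUDA RNGs and all mutable state lives in registered buffers (e.g. BatchNorm running statistics):

import torch

def snapshot_net_state(net):
    # Copy registered buffers (e.g. BatchNorm running statistics) and the
    # global RNG states (which determine the Dropout masks).
    buffers = {name: b.detach().clone() for name, b in net.named_buffers()}
    cpu_rng = torch.get_rng_state()
    cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    return buffers, cpu_rng, cuda_rng

def restore_net_state(net, state):
    # Put buffers and RNGs back exactly as they were, so that a repeated
    # forward pass reproduces the first one.
    buffers, cpu_rng, cuda_rng = state
    with torch.no_grad():
        for name, b in net.named_buffers():
            b.copy_(buffers[name])
    torch.set_rng_state(cpu_rng)
    if cuda_rng is not None:
        torch.cuda.set_rng_state_all(cuda_rng)

The idea would be to call snapshot_net_state(net) right before each forward at NOTE #1 (one saved state per input) and restore_net_state(net, state) right before the matching forward at NOTE #2, so the second pass replays the same Dropout masks and sees the same buffer values. I haven’t verified that this covers every kind of stateful submodule, hence this question.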
Are there any solutions to this? Alternative approaches to computing dy/dp without exploding VRAM are also welcome. Thank you so much!