Hi there,
Is it possible to retrieve an intermediate result from a custom backward pass?
In my user-defined `autograd.Function`, the backward pass computes an intermediate value and then contracts it (e.g. via `sum`) to obtain the correct gradient. That intermediate value is also needed for other calculations besides the gradient, so it would be elegant to store it and reuse it instead of recomputing it.
```python
import torch


class DummyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        factor = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
        return factor * inp

    @staticmethod
    def backward(ctx, bkw):
        def dummy(tensor):
            # do sth.
            return tensor

        # how to retrieve this
        intermediary = dummy(bkw)
        res = torch.sum(intermediary, (-2, -1)).unsqueeze(-1)
        return res


inp = torch.tensor([[1.0], [2.0]], requires_grad=True)
function = DummyFunction.apply
out = function(inp)
gradient = torch.autograd.grad(outputs=out.sum(), inputs=inp)[0]
```
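One possible approach, sketched here under the assumption that a side channel is acceptable: pass a plain Python `dict` into `apply` as an extra, non-tensor argument, keep a reference to it on `ctx` in `forward`, and let `backward` write the intermediate value into it. Autograd ignores non-tensor inputs, but `backward` still has to return one value per `forward` input, so it returns `None` for the dict. The names `store` and `"intermediary"` are hypothetical, and the `dummy` step is reduced to an identity placeholder.

```python
import torch


class DummyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, store):
        # `store` is a caller-owned dict (hypothetical side channel);
        # keeping it on ctx makes it reachable from backward.
        ctx.store = store
        factor = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
                               [[5.0, 6.0], [7.0, 8.0]]])
        return factor * inp

    @staticmethod
    def backward(ctx, bkw):
        # Placeholder for the "do sth." step: identity here.
        intermediary = bkw
        # Save a detached copy so it can be reused outside of autograd.
        ctx.store["intermediary"] = intermediary.detach().clone()
        res = torch.sum(intermediary, (-2, -1)).unsqueeze(-1)
        # One return value per forward input; the dict gets no gradient.
        return res, None


inp = torch.tensor([[1.0], [2.0]], requires_grad=True)
store = {}
out = DummyFunction.apply(inp, store)
gradient = torch.autograd.grad(outputs=out.sum(), inputs=inp)[0]

# Available after the backward pass has run.
intermediary = store["intermediary"]
```

Since the dict is filled only when `backward` actually runs, it stays empty until `torch.autograd.grad` (or `.backward()`) is called. If the function is applied several times, passing a fresh dict per call avoids one call overwriting another's value.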