Is it possible to keep intermediate results of forward for backward?

Hello,

I would like to implement a new torch.autograd.Function where the gradient is closely related to an intermediate result. Is there a way to store the intermediate result for the backward pass to avoid having to compute it again? (Similar to save_for_backward, but that explicitly isn’t it…)

Thank you

Thomas

Why not save_for_backward? It’s there for precisely this purpose.
If not, you can also just assign the intermediate result tensor to self: self.intermediate = intermediate_result
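
Here is a minimal sketch of the attribute approach with a new-style Function, where the object you assign to is ctx rather than self. It assumes a current PyTorch release. The class name SoftplusWithIntermediate and the softplus example are hypothetical. Softplus is a convenient case because its gradient e^x / (1 + e^x) reuses the intermediate e^x from the forward pass. Current PyTorch documentation recommends save_for_backward for tensors, so plain ctx attributes are mainly the tool for non-tensor values.

```python
import torch


class SoftplusWithIntermediate(torch.autograd.Function):
    """Hypothetical example: y = log(1 + exp(x)); backward reuses exp(x)."""

    @staticmethod
    def forward(ctx, x):
        e = torch.exp(x)  # intermediate result: neither an input nor an output
        # Keep the intermediate on ctx so backward does not recompute exp(x).
        ctx.intermediate = e
        return torch.log1p(e)

    @staticmethod
    def backward(ctx, grad_output):
        e = ctx.intermediate
        # d/dx log(1 + e^x) = e^x / (1 + e^x)
        return grad_output * e / (1 + e)


x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = SoftplusWithIntermediate.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, torch.sigmoid(x)))  # True
```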

Thanks!

For me, save_for_backward complained if what I passed was not an input or an output.

RuntimeError: save_for_backward can only save input or output tensors, but argument 2 doesn't satisfy this condition

I’ll assign to self.intermediate then. I just wanted to check that this does not run into problems with advanced (e.g. parallel) use.

The autograd system is designed to handle intermediate results, but only if they were formulated using torch.autograd.Variable() and torch tensor functions, afaik. The error is because one of the arguments is not a Variable.

The question is a bit old, but I encountered the same situation.

It seems that only arguments of the forward method can be saved; saving an intermediate result (tensor) failed with the message below.
Am I right? And is it okay, in terms of performance, to save the tensor like self.intermediate = intermediate_result?

def forward(self, a, b, c):
    # okay
    # self.save_for_backward(a)

    # wrong!!
    self.save_for_backward(a + 1)

RuntimeError: save_for_backward can only save input or output tensors, but argument 2 doesn’t satisfy this condition
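
For comparison, here is a sketch that assumes a recent PyTorch release (1.x or later, not the 0.2/0.3 versions discussed in this thread). In those releases, new-style Functions accept intermediate tensors in ctx.save_for_backward, so the a + 1 case should no longer raise. AddOneSquare is a hypothetical name. Per the PyTorch docs, a Function that saves intermediates this way may not support double backward.

```python
import torch


class AddOneSquare(torch.autograd.Function):
    """Hypothetical example: computes (a + 1) ** 2, saving a + 1 for backward."""

    @staticmethod
    def forward(ctx, a):
        shifted = a + 1  # intermediate tensor (not an input or output)
        # Assumes a recent PyTorch; 0.2/0.3 raised the RuntimeError quoted above.
        ctx.save_for_backward(shifted)
        return shifted * shifted

    @staticmethod
    def backward(ctx, grad_output):
        (shifted,) = ctx.saved_tensors
        # d/da (a + 1)^2 = 2 (a + 1)
        return grad_output * 2 * shifted


a = torch.randn(4, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(AddOneSquare.apply, (a,)))  # True
```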

My understanding is that it is safe to save things yourself by assigning them to (harmlessly named) members of self (ctx for new-style autograd).
In the new style, you need to wrap it yourself to get a Variable.

Best regards

Thomas

Saving the intermediate calculations as members of self worked for me, thank you!

I have the same question (I need to store info for backward).

I am following this example: http://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
with PyTorch 0.2, and I get the same RuntimeError.

Unfortunately, I don’t see where this requirement (store only inputs or outputs) comes from, or how to get around it (storing it as a member of the Function means you can use this function only once).

Any ideas how to store intermediate things?

I see two main options.

  1. Using the autograd.Function instance only once is a great way to do this. In fact this is what used to happen all the time when you used operations on variables and is the right thing to do™.
  2. If you absolutely dislike autograd.Function instances or (like myself) like to feel modern, go for the new-style autograd.Function and change the line with the function application to y_pred = MyReLU.apply(x.mm(w1)).mm(w2). You do not need the relu instance when doing this. This stores inputs and other state in the context ctx, which is what happens nowadays when you use operations on variables.

Best regards

Thomas

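import torch


class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # input is a forward argument, so save_for_backward accepts it
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors replaces the older, deprecated saved_variables attribute
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0  # no gradient flows through negative inputs
        return grad_input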
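
Here is a usage sketch of the new-style call described in option 2. It assumes a current PyTorch release, which is why it uses saved_tensors. The sizes N, D_in, H and D_out are small illustrative values, not the tutorial's. No MyReLU() instance is created, and each apply() call gets its own ctx.

```python
import torch


class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


# Illustrative sizes (assumed): batch, input, hidden, output dimensions.
N, D_in, H, D_out = 8, 5, 4, 3
x = torch.randn(N, D_in)
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

# New-style application: call the static apply() instead of an instance.
y_pred = MyReLU.apply(x.mm(w1)).mm(w2)
loss = y_pred.pow(2).sum()
loss.backward()
print(w1.grad.shape, w2.grad.shape)  # torch.Size([5, 4]) torch.Size([4, 3])
```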

Ok, using the Function instance only once is probably the way to go. Where do I find an example that instantiates the function?

As for the new-style route: the values I need to store are actually integers that occur during the forward pass, so I can’t follow this recipe.
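
If the values are plain Python integers, one option in a current PyTorch release is to keep them as ctx attributes, since save_for_backward only accepts tensors. This is a hedged sketch rather than the thread's own solution. MaxOfVector is a hypothetical example in which the integer is the argmax index found during forward.

```python
import torch


class MaxOfVector(torch.autograd.Function):
    """Hypothetical example: max of a 1-D tensor; the argmax index is an int."""

    @staticmethod
    def forward(ctx, x):
        idx = int(torch.argmax(x))  # integer produced during the forward pass
        # Non-tensor values cannot go through save_for_backward; store on ctx.
        ctx.idx = idx
        ctx.n = x.numel()
        return x[idx].clone()  # clone so the output is not a view of the input

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient is zero everywhere except at the position of the maximum.
        grad_input = grad_output.new_zeros((ctx.n,))
        grad_input[ctx.idx] = grad_output
        return grad_input


x = torch.tensor([0.5, 3.0, -1.0], requires_grad=True)
MaxOfVector.apply(x).backward()
print(x.grad)  # tensor([0., 1., 0.])
```

Each apply() call gets a fresh ctx, so the stored integers belong to that call only. This avoids the "use the function only once" limitation of storing them as members of a Function instance.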

Found an example here: