How to change which tensor is stored on the GPU for the backward pass?

Dear All,

At the end of an intermediate layer in the forward pass, I want to store a modified version of the output of this layer instead of the original one. Then, in the backward pass, I would like to use the modified version to compute the gradient for the next layer (in backward order). Let’s assume that the operation of that intermediate layer itself (ReLU, conv., …) is unchanged.

My questions are

  • How can I ask PyTorch to save the modified tensor instead and detach the original one from the graph?
  • If PyTorch allows this, how will the GPU memory allocation be affected? Will it change by the difference in size between the original and modified tensors?

Any help is much appreciated.

I think you could create a custom autograd.Function as described here, store the tensors via ctx.save_for_backward, and use them afterwards in the backward method.
Since this new tensor would be stored instead of the original one, the memory usage would depend on its size.
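As a rough illustration of the memory point, here is a sketch comparing the raw element storage of a float32 activation with two hypothetical "modified" versions one might save instead: a float16 copy and a boolean mask. The shape is an arbitrary assumption, and the numbers count element storage only. The CUDA caching allocator rounds allocations up and may keep freed blocks cached, so the usage you observe on the GPU can differ somewhat.

```python
import torch


def storage_bytes(t: torch.Tensor) -> int:
    # Bytes occupied by the tensor's elements (ignores allocator rounding/caching).
    return t.element_size() * t.numel()


# Hypothetical activation of shape (batch, channels, H, W): 524288 elements.
output = torch.randn(8, 64, 32, 32)

# Two possible "modified" versions that could be saved instead of the original.
as_half = output.half()  # same shape, 2 bytes per element instead of 4
as_mask = output > 0     # bool mask, 1 byte per element

print(storage_bytes(output))   # 2097152
print(storage_bytes(as_half))  # 1048576
print(storage_bytes(as_mask))  # 524288
```

So the saving is roughly the size difference between what you would have saved and what you save instead.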

Dear @ptrblck,

Thank you so much for your reply.

Assume that in the forward pass I use the following code to compute the output and modify it:

    output = run_function(input)
    output = modify(output)

where run_function is ReLU, conv., etc. Then, in the backward pass, I retrieve the modified output with:

    (output, ) = ctx.saved_tensors

Together with grad_output, is the following code sufficient to compute the gradient with respect to the input?

    torch.autograd.backward(output, grad_output)

If not, what should I do?

Thank you so much for your time.
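To see why the torch.autograd.backward call above cannot work as written, here is a minimal sketch, assuming a hypothetical modify that simply doubles the ReLU output (the class name SaveModified is also made up). Autograd does not record the operations inside a custom Function's forward, so the saved modified tensor has requires_grad=False and no grad_fn, and autograd raises a RuntimeError when asked to backpropagate through it. Even with a graph, torch.autograd.backward accumulates into the .grad of leaf tensors and returns None, so it would not give backward a gradient to return.

```python
import torch


class SaveModified(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.relu(input)
        modified = output * 2  # hypothetical stand-in for modify(output)
        ctx.save_for_backward(modified)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (modified,) = ctx.saved_tensors
        # `modified` was computed inside forward, which autograd does not record,
        # so it has requires_grad=False and no grad_fn; this call raises.
        torch.autograd.backward(modified, grad_output)
        return grad_output  # not reached


x = torch.randn(4, requires_grad=True)
try:
    SaveModified.apply(x).sum().backward()
except RuntimeError as err:
    print("RuntimeError:", err)
```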

No, you would have to implement the backward method manually as shown in the linked example.
E.g. for ReLU:

    # method of a torch.autograd.Function subclass
    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

where you could now change this computation and add your custom methods to it.
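Putting the pieces together, here is a minimal sketch of a ReLU-like Function whose forward saves a modified tensor instead of the input, and whose backward is written manually against that saved tensor. The "modification" here (a boolean mask of where the input is positive) and the class name MaskedReLU are hypothetical choices for illustration; a real modification would need its own matching backward formula.

```python
import torch


class MaskedReLU(torch.autograd.Function):
    """ReLU whose backward uses a saved bool mask instead of the saved input."""

    @staticmethod
    def forward(ctx, input):
        # Hypothetical "modification": store only where the input is positive.
        # A bool tensor uses 1 byte per element instead of 4 for float32.
        ctx.save_for_backward(input > 0)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # d relu(x)/dx is 1 where x > 0 and 0 elsewhere, so the mask is enough.
        return grad_output * mask


# Compare against the built-in ReLU on random data.
torch.manual_seed(0)
x = torch.randn(16, requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)
weights = torch.randn(16)

(MaskedReLU.apply(x) * weights).sum().backward()
(torch.relu(x_ref) * weights).sum().backward()

print(torch.allclose(x.grad, x_ref.grad))  # True
```

Saving a mask works here only because ReLU's derivative depends on nothing but the sign of its input. For a conv layer, the backward needs the actual input values, so an arbitrary modification of the saved tensor would change the computed gradient.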

I got it. Thank you so much @ptrblck!