Output a gradient to a user-defined tensor

I see that torch.autograd.grad() can compute the gradient of outputs with respect to inputs and return it. If I understand the code correctly, the returned gradient tensor is allocated while performing the computation. I wonder if it is possible to ask torch.autograd.grad() to write the results into a user pre-defined tensor.

Or, is there any other way to have parameter.grad pre-allocated before the backward() operation?

My use case might be different from the common ones. I hacked into PyTorch's memory management code to allow an allocated tensor to change its storage to a user-defined GPU memory region. I also want to apply this to the gradient tensors. However, since a gradient tensor is allocated during the backward() runtime, this is very hard to realize for gradients, as I don't want an additional memory copy.

Hi,

When doing .backward(), if the leaf already has a .grad field, we try to re-use it to accumulate the new gradients.
When doing .grad(), we don't have any API to do that, because the gradient Tensor we return is just the one that was output by the last backward function. And none of our backward functions can write into a Tensor in-place. So the only thing we can do is:

import torch

def inplace_grad(out, inp, grad_out, grad_in_buffers, **kwargs):
    # Compute the gradients as usual, then copy them into the
    # user-provided, pre-allocated buffers.
    grads = torch.autograd.grad(out, inp, grad_out, **kwargs)
    for buf, g in zip(grad_in_buffers, grads):
        buf.copy_(g)

But even with changes to the core, I'm afraid we can't do better than this.


Thank you for the rapid reply. Since .backward() can accumulate into an existing gradient, I wonder whether param.grad can be allocated before calling .backward(), or before the execution_engine runs in torch/autograd/__init__.py or in torch/csrc/autograd/engine.cpp (func: THPEngine_run_backward).

If you use .backward(), then you can simply do that by setting the .grad field of your parameters before calling the .backward() function. No need to change anything else.
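
A minimal sketch of that idea (the tiny nn.Linear model and the shapes below are just an illustration, not part of the original code):

import torch
import torch.nn as nn

model = nn.Linear(10, 10)

# Pre-allocate the .grad fields; .backward() will accumulate
# into these tensors instead of allocating new ones.
for p in model.parameters():
    p.grad = torch.zeros_like(p)

out = model(torch.randn(4, 10))
out.sum().backward()  # gradients land in the pre-set .grad tensors

Since .backward() accumulates, the pre-set tensors should be zeroed (or re-zeroed between iterations) before the call.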

I see. So we can assign a gradient tensor to .grad (an example is shown below) to avoid another gradient tensor being created by torch during .backward(). Am I right?

The example below is adapted from another answer: How to split backward process wrt each layer of neural network?
So in the example below, I can calculate the gradients manually and have a pre-defined gradient tensor for every layer (see the backward() function)?

I greatly appreciate your help!

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        ])

    def forward(self, x):
        self.output = []
        self.input = []
        for layer in self.layers:
            # detach from previous history
            x = Variable(x.data, requires_grad=True)
            self.input.append(x)

            # compute output
            x = layer(x)

            # add to list of outputs
            self.output.append(x)
        return x

    def backward(self, g):
        for i, output in reversed(list(enumerate(self.output))):
            # pre-assign a tensor to .grad before calling backward
            a = torch.ones(4, 10)
            self.input[i].grad = a
            if i == (len(self.output) - 1):
                # for last node, use g
                output.backward(g)
                print(self.input[i].grad.shape)
            else:
                output.backward(self.input[i + 1].grad.data)


model = Net()
inp = Variable(torch.randn(4, 10))
output = model(inp)
gradients = torch.randn(*output.size())
model.backward(gradients)

Hi,

Note that to detach a Tensor, you should not do x = Variable(x.data, requires_grad=True) but rather x = x.detach().requires_grad_() (.data has other side effects and should not be used anymore).

Same for the backward, you don’t need the .data.
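
For reference, a minimal self-contained sketch of the recommended pattern (the toy tensors here are made up; in the sample above the same change means passing self.input[i + 1].grad directly to output.backward()):

import torch

x = torch.randn(4, 10)

# Recommended way to detach x from its previous history:
x = x.detach().requires_grad_()

y = (x * 2).sum()
y.backward()            # no .data needed anywhere
print(x.grad.shape)     # torch.Size([4, 10])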

You can pre-allocate the gradients like this, yes. But in the code sample you gave, it has very limited benefit (I assume because this is only part of your real code).
Note, though, that during the backward pass, Tensors will still be created to store the gradients of the intermediate ops.

Again, thank you for the quick reply. The code sample is for manually calculating the gradients layer by layer during the backward pass. I wonder if you have a better way to achieve this goal. I did a lot of searching on how to enable this, but this is the only way I found. It would be greatly appreciated if you could give some hints.

In addition, you mentioned that tensors will still be created to store the gradients. Do you mean that
(1) the outputs written into the user-defined tensors are actually copied from some internal result tensors,
(2) this also applies to the regular .backward() case (i.e., when we don't pre-assign a tensor to .grad but let it be allocated automatically), and
(3) this is not avoidable, since we cannot directly write the results in-place?

Yes, the code looks good for that! What I meant is that pre-allocating the .grad as done in this particular sample is not very useful. But only that part; the rest of the code is fine.

For your questions:

  1. Yes
  2. Yes
  3. Yes

The last point is the most important here.
Changing this would mean changing a key invariant of the autograd engine: backward functions never modify their inputs in-place.

Thank you so much!!! You are awesome!

My question is solved.