Is it safe to modify the output's grad in place and return it as the input's grad?

Consider the following simple example, which defines an autograd Function Fn that takes [x0, x1] and returns [y0 = x0, y1 = x0 + x1].

import torch

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # let's suppose x is a size (2,) tensor.
        y = torch.empty(2)
        y[0] = x[0]
        y[1] = x[0] + x[1]
        return y

    @staticmethod
    def backward(ctx, y_grad):
        y_grad[0] += y_grad[1]
        return y_grad

By simple calculations we know that x_grad[0] = y_grad[0] + y_grad[1]; x_grad[1] = y_grad[1]. Now, instead of forming a new tensor x_grad, I just modify y_grad in-place like the example above, and return the modified result as x_grad. Is this a valid approach?
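One way to double-check the hand-derived formula is to compute the Jacobian of the forward map and take the vector-Jacobian product, which is what backward has to return. This is only a sketch: `f` is a hypothetical out-of-place rewrite of the forward (using `torch.stack` instead of writing into `torch.empty`), and `torch.autograd.functional.jacobian` is used purely as a checking tool.

```python
import torch

def f(x):
    # Hypothetical out-of-place version of Fn.forward: y0 = x0, y1 = x0 + x1
    return torch.stack([x[0], x[0] + x[1]])

x = torch.randn(2)
J = torch.autograd.functional.jacobian(f, x)  # J[i, j] = dy_i / dx_j -> [[1, 0], [1, 1]]

y_grad = torch.randn(2)
x_grad = J.T @ y_grad  # vector-Jacobian product: what backward must return

# Matches x_grad[0] = y_grad[0] + y_grad[1], x_grad[1] = y_grad[1]
expected = torch.stack([y_grad[0] + y_grad[1], y_grad[1]])
print(torch.allclose(x_grad, expected))
```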

My current understanding is that this is valid: since y is not a leaf tensor, its grad is short-lived and can be transferred or borrowed. Also, I haven’t seen any bugs or issues in my use cases (with more complicated function expressions, of course). That said, I’m not fully sure I’m doing the right thing; maybe my use cases are just too simple. Is there a scenario where this approach might break?

I’m not sure what your concern is; I would guess you are unsure about the in-place operations in the backward?
Autograd is disabled inside an autograd.Function, so it’s on you to guarantee that all applied operations are correct. E.g., if you reuse y_grad later in backward, you need to account for (and expect) the in-place manipulation.

My concern is whether modifying the input of the backward function (here the input means y_grad, or the output gradient in general) could lead to side effects I’m not aware of, especially in my case, where I modify it and directly return it as the input’s gradient (x_grad). Inside the function I can guarantee that the logic is correct, and I’m aware that autograd is not active inside this forward/backward function body. But what about outside the function? Is it possible that y_grad (grad_output) is used after the backward call by some utility functions I’m not aware of? Or could issues arise if controls like retain_graph or retain_grad are used outside the function body, leading to some bad interactions?

Also, I was reading the Extending PyTorch doc and just found that it says:

It is important NEVER to modify these in-place.

(Here “these” means the output gradients.) The doc seems to imply that issues may occur, especially when you try to compute a double backward or higher-order gradients (I haven’t figured out what exactly these mean in the context of the PyTorch computational graph). So back to the question: what are the side effects of modifying grad_output in place and returning it directly? What bad things can happen?
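For readers unsure what “double backward” means: it is differentiating a gradient that was itself computed with create_graph=True, so the operations run during the first backward are recorded into a new graph. A minimal sketch with a plain tensor op (not a custom Function):

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# First backward: dy/dx = 3 * x**2 = 27; create_graph=True records this computation
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)

# Second (double) backward: differentiate dy/dx again, d2y/dx2 = 6 * x = 18
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)

print(dy_dx.item(), d2y_dx2.item())
```

Any tensor that the first backward saved for this second pass (which can include grad_output itself, the case the quoted warning is about) must still hold its original value when the second pass runs.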

That’s a good point, and indeed I would guess that using the intermediate gradient after the in-place manipulation could yield wrong results.
However, I cannot reproduce it using this simple code:

import torch

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # let's suppose x is a size (2,) tensor.
        y = torch.empty(2)
        y[0] = x[0]
        y[1] = x[0] + x[1]
        return y

    @staticmethod
    def backward(ctx, y_grad):
        print("before: ", y_grad)
        y_grad[0] += y_grad[1]
        print("after: ", y_grad)
        return y_grad
    

x = torch.randn(2, requires_grad=True)
out = Fn.apply(x)

out.retain_grad()
out.mean().backward()
# before:  tensor([0.5000, 0.5000])
# after:  tensor([1.0000, 0.5000])
print(out.grad)
# tensor([0.5000, 0.5000])
print(x.grad)
# tensor([1.0000, 0.5000])

I’m sure @albanD would know some examples where these inplace ops might be dangerous.

Many thanks for digging into this and trying it out! I continued testing with the following example. I found that the retained grad (of out) is somehow not the same tensor as the one passed into the backward call, while the y_grad in the backward call is indeed modified and passed out as x_grad, with no extra copies. Check this code snippet:

import torch

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # let's suppose x is a size (2,) tensor.
        y = torch.empty(2)
        y[0] = x[0]
        y[1] = x[0] + x[1]
        return y

    @staticmethod
    def backward(ctx, y_grad):
        print(f"y_grad before: {y_grad} at {y_grad.data_ptr()}")
        y_grad[0] += y_grad[1]
        print(f"y_grad after: {y_grad} at {y_grad.data_ptr()}")
        return y_grad
    

x = torch.randn(2, requires_grad=True)
out = Fn.apply(x)

out.retain_grad()
out.mean().backward()
# y_grad before: tensor([0.5000, 0.5000]) at 140496189355520
# y_grad after: tensor([1.0000, 0.5000]) at 140496189355520
print(f'out.grad: {out.grad} at {out.grad.data_ptr()}')
# out.grad: tensor([0.5000, 0.5000]) at 140496189437440
print(f'x.grad: {x.grad} at {x.grad.data_ptr()}')
# x.grad: tensor([1.0000, 0.5000]) at 140496189355520

By comparing the data_ptrs, we see that y_grad is indeed the same tensor as x.grad (after the call), while out.grad is a different tensor, which appears to have been cloned beforehand.

So the observation is that retain_grad might have registered a hook to clone the grad before the next backward pass (my guess) so it doesn’t break this example. @albanD Can you give some other examples where this in-place output-grad modification may break something?

I just read some documentation on PyTorch double backward and higher-order gradient computation (a smart approach)! Now I can say that in-place modification of the output grad can break double backward in certain scenarios. Consider the following x * x toy example:

import torch

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, y_grad):
        x, = ctx.saved_tensors
        x_grad = 2 * x * y_grad
        # y_grad += 1.0  # if we uncomment this line, double backward will break due to y_grad's in-place modification.
        return x_grad
    

x = torch.tensor(2.0, requires_grad=True)
y = Fn.apply(x)

y.backward(create_graph=True)

# Now try some double backward. Will break if we do y_grad += 1.0 in the backward.
x.grad.backward()

That said, the example still seems too naive. It is of course broken, since we made an in-place modification to a tensor that is needed for the (double) backward computation. Also, we received a clear error, so at least we are aware. What if I don’t care about double backward and higher-order grads in my use case (I can mark my backward as once_differentiable)? Will this in-place modification of the output grad still create issues in some scenarios, especially silent issues (as opposed to the case above, where we at least get an error)?
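If double backward is not needed, one sketch of a safer pattern is to mark backward with `@once_differentiable` and write into a private clone of y_grad rather than into y_grad itself. This reuses the Fn from the first post, rewritten with `torch.stack`. Note that `once_differentiable` only concerns higher-order gradients; it does not by itself make in-place writes to y_grad safe, which is why the clone is there.

```python
import torch
from torch.autograd.function import once_differentiable

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # y0 = x0, y1 = x0 + x1 for a size (2,) tensor x
        return torch.stack([x[0], x[0] + x[1]])

    @staticmethod
    @once_differentiable
    def backward(ctx, y_grad):
        # Work on a private copy so the incoming y_grad is never mutated
        x_grad = y_grad.clone()
        x_grad[0] += y_grad[1]
        return x_grad

x = torch.tensor([1.0, 2.0], requires_grad=True)
g = torch.tensor([3.0, 5.0])
Fn.apply(x).backward(g)
print(x.grad)  # tensor([8., 5.])
print(g)       # tensor([3., 5.]): the caller's gradient is untouched
```

The cost is one extra allocation per backward call, in exchange for not relying on nobody else holding the same gradient tensor.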

Even more interesting cases! Now I’m puzzled by this simple Clone gradient test:

import torch

class Clone(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = x.clone()
        return y

    @staticmethod
    def backward(ctx, y_grad):
        print(f'inside backward, y_grad.stride(): {y_grad.stride()}')
        print(f'inside backward, y_grad.data_ptr(): {y_grad.data_ptr()}')
        return y_grad  # to examine the side effect

# case 1
x = torch.rand((2, 2), requires_grad=True)
y = Clone.apply(x)
y.mean().backward()
print(f'done backward, x.grad.stride(): {x.grad.stride()}')
print(f'done backward, x.grad.data_ptr(): {x.grad.data_ptr()}')
# out:
#   inside backward, y_grad.stride(): (2, 1)
#   inside backward, y_grad.data_ptr(): 140442810773632
#   done backward, x.grad.stride(): (2, 1)
#   done backward, x.grad.data_ptr(): 140442810773632
# This case is easy to understand. y_grad is simply passed out.

# case 2
x = torch.rand((2, 2), requires_grad=True)
y = Clone.apply(x)
y.sum().backward()
print(f'done backward, x.grad.stride(): {x.grad.stride()}')
print(f'done backward, x.grad.data_ptr(): {x.grad.data_ptr()}')
# out:
#   inside backward, y_grad.stride(): (0, 0)
#   inside backward, y_grad.data_ptr(): 140442883547904
#   done backward, x.grad.stride(): (2, 1)
#   done backward, x.grad.data_ptr(): 140442883507008
# Interesting case, but still understandable.
# The grad of sum gives a dummy scalar-expanded tensor with stride (0, 0),
# which forces x.grad to be copy-initialized with a new tensor to match x's memory layout
# see:
# https://pytorch.org/docs/stable/autograd.html#default-gradient-layouts

# case 3:
x = torch.rand((2, 2), requires_grad=True)
y = Clone.apply(x)
y_grad = torch.rand_like(y)
print(f'before backward, y_grad.stride(): {y_grad.stride()}')
print(f'before backward, y_grad.data_ptr(): {y_grad.data_ptr()}')
y.backward(y_grad)
print(f'done backward, x.grad.stride(): {x.grad.stride()}')
print(f'done backward, x.grad.data_ptr(): {x.grad.data_ptr()}')
# out:
#   before backward, y_grad.stride(): (2, 1)
#   before backward, y_grad.data_ptr(): 140442809548672
#   inside backward, y_grad.stride(): (2, 1)
#   inside backward, y_grad.data_ptr(): 140442809548672
#   done backward, x.grad.stride(): (2, 1)
#   done backward, x.grad.data_ptr(): 140442809838848
# This one is the really interesting one!
# What makes torch think that x.grad needs a new copy assignment, as opposed to case 1?
# Looks like torch figured out some underlying data ownership and made such a decision,
# but how, why, and in what scenario will it do so?

Case 3 is quite interesting. To my knowledge, it involves object lifetime and ownership. If this were C++, I could explain it: the backward input y_grad is std::move’d to x_grad in both case 1 and case 3, but in case 1 no other object holds the underlying data, so torch can decide that x.grad uniquely owns it and no extra copy is needed. In case 3, another named object (the global variable y_grad) holds the same data, making the reference count > 1, so torch decides to make an extra copy so that x.grad uniquely owns its data. BUT that’s all C++. How exactly is this done through the Python API? My naive understanding of Python is that it uses a global garbage collector to manage everything, and that it’s not guaranteed which object is temporary and can be freed immediately. This example suggests I’m likely wrong. So does Python also have some local-variable management, like C++, that enables this kind of ownership transfer? Otherwise, how is the above done?
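On the Python side: CPython manages object lifetime primarily by reference counting (a cycle collector only supplements it), so an object is freed as soon as its last reference goes away, and a named variable such as y_grad keeps one reference alive. A small illustration with `sys.getrefcount` (CPython-specific; `t` and `holder` are illustrative names):

```python
import sys
import torch

t = torch.rand(2, 2)
base = sys.getrefcount(t)        # includes the temporary reference from the call itself

holder = [t]                     # a second owner of the same Python tensor object
with_holder = sys.getrefcount(t)

del holder                       # the list is freed immediately and drops its reference
after_del = sys.getrefcount(t)

print(base, with_holder, after_del)
```

My understanding (an assumption, not verified against every PyTorch version) is that each Python tensor object holds a reference to the underlying C++ tensor, and that AccumulateGrad checks the C++ use count before deciding whether it can reuse the incoming gradient or must copy it. That would explain the extra copy in case 3, where the global y_grad is still alive.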

Oh, I’m a bit late to this party 😃

It is important NEVER to modify these in-place.

This is for a few reasons:

  • Indeed when double backward is involved, it might prevent the double backward from running if that value was needed for backward.
  • More importantly, that Tensor is used as-is by anything that requires that value of the gradient. In particular, for (a + b), ga and gb will be the same Tensor. If you modify ga in place before gb is used, gb will be wrong!
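A sketch of the (a + b) case, using a hypothetical `ScaleInPlace` Function (y = 2 * x) whose backward writes into grad_output. Both leaves should receive a gradient of 2, but since the add node hands the same Tensor to both branches, whichever `ScaleInPlace` backward runs second sees an already-doubled value, so at least one leaf ends up wrong with no error raised. The explicit `backward(torch.ones(3))` is used because the gradient of `sum()` is an expanded tensor (as in case 2 above), on which `mul_` would raise instead of silently corrupting.

```python
import torch

class ScaleInPlace(torch.autograd.Function):
    # Hypothetical example: y = 2 * x, with an UNSAFE backward
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        grad_output.mul_(2)  # unsafe: grad_output may be shared with other nodes
        return grad_output

x1 = torch.ones(3, requires_grad=True)
x2 = torch.ones(3, requires_grad=True)
s = ScaleInPlace.apply(x1) + ScaleInPlace.apply(x2)

# Explicit, writable gradient instead of s.sum().backward()
s.backward(torch.ones(3))

# Mathematically both should be tensor([2., 2., 2.])
print(x1.grad, x2.grad)
```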

I haven’t looked at your code in great detail, but I think a lot of the weirdness you see comes from what AccumulateGrad does (the node responsible for populating the .grad field).
It makes sure that no two .grad fields share the same Tensor. So even if Tensors were the same during the backward, it will make a clone() before populating the .grad field to ensure they’re all different.
You can use hooks or autograd.grad to look at these values without the side effects of AccumulateGrad.
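A sketch of both approaches with the Clone Function from the earlier test (`seen`, `record`, `x2` and `gx` are illustrative names). A tensor hook on the leaf x receives the gradient before AccumulateGrad runs, and `torch.autograd.grad` returns the gradient without touching `.grad` at all:

```python
import torch

seen = {"backward": []}

class Clone(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, y_grad):
        seen["backward"].append(y_grad.data_ptr())  # tensor handed to backward
        return y_grad

g = torch.rand(2, 2)

# 1) A hook on the leaf sees the incoming gradient before AccumulateGrad
x = torch.rand(2, 2, requires_grad=True)

def record(grad):
    seen["hook"] = grad.data_ptr()
    # returning None leaves the gradient unchanged

x.register_hook(record)
Clone.apply(x).backward(g)
print(seen["hook"] == seen["backward"][0])  # same tensor as returned by backward

# 2) autograd.grad returns the gradient directly; x2.grad is never populated
x2 = torch.rand(2, 2, requires_grad=True)
(gx,) = torch.autograd.grad(Clone.apply(x2), x2, grad_outputs=g)
print(x2.grad)  # None
```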

Got it. Many thanks for the detailed explanations!