Understanding Autograd + ReLU(inplace = True)

How do I compute/check/understand gradients of in-place ReLU? I used clone() to get around the error "a leaf Variable that requires grad has been used in an in-place operation". The gradient should obviously be 0, but I get 1.

import torch
import torch.nn.functional as F

# not in place
x = torch.tensor([-1.0]).requires_grad_().clone()
print(torch.autograd.grad(F.relu(x), (x, ))[0])
# tensor([0.])

# inplace
x = torch.tensor([-1.0]).requires_grad_().clone()
print(torch.autograd.grad(F.relu_(x), (x, ))[0])
# tensor([1.])

My final use case is testing some ResNet-like code like the following:

def residual_activation_inplace(x, residual, activation, inplace):
  for r in residual:
    x += r
  return activation(x, inplace = inplace)

x = torch.rand(3, 4)
residual = [torch.rand(3, 4), torch.rand(3, 4)]
y = residual_activation_inplace(x, residual, F.relu, inplace = True)
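A possible way to test this kind of function, sketched under the assumption that what you want to check is the gradient with respect to the original input: pass a clone of a leaf tensor into the function (so the in-place `x += r` never touches the leaf itself) and compare both the output and the gradient against an out-of-place reference. The names x_leaf, y_ref and g_ref are illustrative only.

```python
import torch
import torch.nn.functional as F


def residual_activation_inplace(x, residual, activation, inplace):
    for r in residual:
        x += r
    return activation(x, inplace=inplace)


torch.manual_seed(0)
x_leaf = torch.randn(3, 4, requires_grad=True)
residual = [torch.randn(3, 4), torch.randn(3, 4)]

# In-place path: pass a clone so the leaf itself is never modified in place.
y_inplace = residual_activation_inplace(x_leaf.clone(), residual, F.relu, inplace=True)
(g_inplace,) = torch.autograd.grad(y_inplace.sum(), (x_leaf,))

# Reference path: the same computation written without in-place ops.
y_ref = F.relu(x_leaf + residual[0] + residual[1])
(g_ref,) = torch.autograd.grad(y_ref.sum(), (x_leaf,))

print(torch.equal(y_inplace, y_ref), torch.equal(g_inplace, g_ref))  # True True
```

The additions happen in the same order in both paths, so exact equality (torch.equal) is a reasonable check here; for more complex code, torch.allclose would be the safer comparison.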

Hi,

Since you apply the ReLU in place in the second case, x now refers to the output of the ReLU. So you are actually computing dx/dx = 1.

If you keep a reference to the original x (before the in-place operation) as below, you get the gradient you expect:

x = torch.tensor([-1.0]).requires_grad_()
x2 = x.clone()
print(torch.autograd.grad(F.relu_(x2), (x, ))[0])
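To make the difference concrete, here is a small sketch that runs both variants on fresh graphs and also checks that F.relu_ writes into the tensor it was given rather than allocating a new one. The helper names grad_wrt_leaf and grad_wrt_clone are made up for illustration.

```python
import torch
import torch.nn.functional as F


def grad_wrt_leaf():
    # Keep a handle on the leaf; the in-place op is applied to a clone of it.
    x = torch.tensor([-1.0]).requires_grad_()
    x2 = x.clone()
    out = F.relu_(x2)  # modifies x2 in place and returns it
    same_storage = out.data_ptr() == x2.data_ptr()
    (g,) = torch.autograd.grad(out, (x,))
    return g, same_storage


def grad_wrt_clone():
    # As in the question: after relu_, the name x refers to the ReLU output.
    x = torch.tensor([-1.0]).requires_grad_().clone()
    out = F.relu_(x)
    (g,) = torch.autograd.grad(out, (x,))  # d(out)/d(out)
    return g


g_leaf, same_storage = grad_wrt_leaf()
g_clone = grad_wrt_clone()
print(g_leaf.tolist(), g_clone.tolist(), same_storage)  # [0.0] [1.0] True
```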

In spirit, you can replace F.relu_(x) by y = F.relu(x); x = y, where x is the actual Python variable, so the change affects everything that uses x afterwards.
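One place where the analogy holds only "in spirit": an in-place op changes the tensor itself, so every name that refers to it sees the new values, whereas the out-of-place version only rebinds one name. A minimal illustration, with made-up variable names:

```python
import torch
import torch.nn.functional as F

a = torch.tensor([-1.0, 2.0])
alias = a            # a second Python name for the same tensor
F.relu_(a)           # writes into a's memory
print(alias.tolist())    # [0.0, 2.0]

b = torch.tensor([-1.0, 2.0])
alias_b = b
b = F.relu(b)        # out-of-place: b is rebound to a new tensor
print(alias_b.tolist())  # [-1.0, 2.0]
print(b.tolist())        # [0.0, 2.0]
```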

d(x)/dx = 1, makes sense 🙂

Hi, guys,

I was confused by the gradient of the in-place ReLU until I found this thread. However, it is still unclear to me how the backward pass of in-place ReLU works.

Based on https://github.com/pytorch/pytorch/blob/master/tools/autograd/derivatives.yaml

name: relu_(Tensor(a!) self) -> Tensor(a!)
  self: threshold_backward(grad, result, 0)

we can see that the in-place ReLU (relu_) relies on the output tensor (result) for its backpropagation. However, this tensor has already been modified into the non-negative range, so it is confusing how the gradient can be passed back once the sign of the input has been lost.
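For reference, a sketch of what the threshold_backward(grad, result, 0) entry computes, assuming its semantics are "zero the incoming gradient wherever the saved tensor is <= threshold, pass it through otherwise". threshold_backward_ref is a hypothetical pure-Python stand-in, not the ATen kernel. Fed only the overwritten, non-negative result, it still matches autograd's gradient for relu_:

```python
import torch
import torch.nn.functional as F


def threshold_backward_ref(grad, result, threshold):
    # Hypothetical stand-in for threshold_backward: zero grad where result <= threshold.
    return torch.where(result <= threshold, torch.zeros_like(grad), grad)


x_leaf = torch.tensor([-2.0, -0.5, 0.5, 3.0], requires_grad=True)
x = x_leaf.clone()
result = F.relu_(x)  # x is overwritten; the negative entries are now 0
grad = torch.ones_like(result)

(autograd_grad,) = torch.autograd.grad(result, (x_leaf,), grad)
manual_grad = threshold_backward_ref(grad, result.detach(), 0)

print(autograd_grad.tolist())  # [0.0, 0.0, 1.0, 1.0]
print(manual_grad.tolist())    # [0.0, 0.0, 1.0, 1.0]
```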

See the example below: for the in-place ReLU, the saved_tensors in the backward function are non-negative.

import torch

class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0  # for the in-place version, is grad_input == grad_output, since input was modified into the non-negative range?
        return grad_input

So is the only way to get a correct in-place ReLU backward pass to save another tensor that indicates the sign of the input (flag = input < 0)?

In other words, for in-place ReLU, is a flag tensor (e.g. flag = input < 0) saved for backward, rather than the input/output tensor?

Hi,

The gradient we want to compute here is indeed 1 if input > 0 and 0 if input <= 0.
The nice thing is that input <= 0 <=> relu(input) = 0.
So we can compute the gradient from the result alone, with grad_input[result == 0] = 0 (or with result <= 0, which gives the same mask because result >= 0).
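Applied to the MyReLU example above, that idea could look like the following sketch. MyReLUInplace is a hypothetical name; it overwrites its input, declares that with ctx.mark_dirty, saves only the overwritten tensor (which is now the result), and masks the gradient with result == 0, so no extra flag tensor is stored.

```python
import torch
import torch.nn.functional as F


class MyReLUInplace(torch.autograd.Function):
    """Hypothetical in-place ReLU that keeps only its (overwritten) output."""

    @staticmethod
    def forward(ctx, input):
        input.clamp_(min=0)           # overwrite input with relu(input)
        ctx.mark_dirty(input)         # tell autograd the input was modified in place
        ctx.save_for_backward(input)  # this is now the result, not the original input
        return input

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[result == 0] = 0   # result == 0  <=>  original input <= 0
        return grad_input


x_leaf = torch.tensor([-1.0, 0.5, 2.0, -3.0], requires_grad=True)
x = x_leaf.clone()                    # in-place ops on a leaf that requires grad are rejected
out = MyReLUInplace.apply(x)
(g,) = torch.autograd.grad(out.sum(), (x_leaf,))

# Out-of-place reference for comparison.
(g_ref,) = torch.autograd.grad(F.relu(x_leaf).sum(), (x_leaf,))

print(g.tolist())                      # [0.0, 1.0, 1.0, 0.0]
print(torch.equal(g, g_ref))           # True
print(out.data_ptr() == x.data_ptr())  # True: the output shares x's memory
```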


Thank you for the explanation. With grad_input[result == 0] = 0, there is no need to store an extra flag. Clear.