How can I modify a ReLU layer's backward?

I want to modify the backward of ReLU so that it simply passes through the gradients coming from the top, rather than zeroing out the ones where the unit is off.

I looked at this thread and couldn't get much out of it. Looking here, while this works for making the gradients zero, I am not sure what the components of grad_input are and which one I should modify to get a pass-through.

You can use a backward hook on the ReLU Module:

import torch
import torch.nn as nn

inp = torch.randn(5, requires_grad=True)
print(inp) # [ 1.7485,  0.5180, -1.2110, -1.3865, -0.6293]

relu = nn.ReLU()

# Without hook:
out = relu(inp)
out.sum().backward()
print(inp.grad) # [1., 1., 0., 0., 0.]

inp.grad.zero_()
# With hook
def grad_hook(mod, grad_in, grad_out):
    # Ignore the computed grad_input
    # Make it look like this is an identity
    return grad_out
relu.register_backward_hook(grad_hook)

out = relu(inp)
out.sum().backward()
print(inp.grad) # [1., 1., 1., 1., 1.]
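
(Side note: on more recent PyTorch versions, register_backward_hook is deprecated in favour of register_full_backward_hook; this is a minimal sketch of the same idea with that API, assuming it is available in your version:)

import torch
import torch.nn as nn

inp = torch.randn(5, requires_grad=True)
relu = nn.ReLU()

def full_grad_hook(mod, grad_in, grad_out):
    # grad_in and grad_out are tuples of Tensors.
    # Returning grad_out here replaces the computed input gradient,
    # so the backward acts like an identity.
    return grad_out

relu.register_full_backward_hook(full_grad_hook)

out = relu(inp)
out.sum().backward()
print(inp.grad)  # all ones, whatever the sign of inp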

Thanks for the reply. From this post, however, it seems that what one has to return is a modification of grad_input, whereas in the snippet in your post we seem to be returning a modification of grad_out.

That depends on what you want to do.
If you want your relu's backward to act as if it were an identity, then the gradient of the output with respect to the input is 1, so what you want your backward to compute is grad_output * 1, ignoring the grad_input computed by the relu's backward since you are replacing it.
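
Another way to get that identity-like backward, if you prefer not to use hooks, would be a custom autograd Function. This is just a sketch of the general pattern (the name PassThroughReLU is made up for illustration):

import torch

class PassThroughReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Usual ReLU forward: zero out the negative entries.
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity backward: pass the incoming gradient straight through
        # instead of masking it where the input was negative.
        return grad_output

inp = torch.randn(5, requires_grad=True)
out = PassThroughReLU.apply(inp)
out.sum().backward()
print(inp.grad)  # all ones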

Yeah, I see your point: for a pointwise activation, grad_input and grad_output are of the same size, so it's straightforward to use grad_output directly as a substitute for grad_input.

An extra question, please: from other posts it seems that modifying gradients is not so straightforward for other layers; is that still the case?
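
For instance, this is the kind of quick check I would run for a non-pointwise layer; the Linear layer, the sizes, and the use of register_full_backward_hook are just my guesses for how to inspect this:

import torch
import torch.nn as nn

lin = nn.Linear(3, 4)  # arbitrary sizes, just for illustration

def print_shapes(mod, grad_in, grad_out):
    # grad_in: gradients w.r.t. the module inputs (not its parameters)
    # grad_out: gradients w.r.t. the module outputs
    print([g.shape for g in grad_in])   # [torch.Size([2, 3])]
    print([g.shape for g in grad_out])  # [torch.Size([2, 4])]

lin.register_full_backward_hook(print_shapes)

x = torch.randn(2, 3, requires_grad=True)
lin(x).sum().backward()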

Hi, thanks a lot for answering this question @albanD. Your example already makes it quite clear for me; however, I do have an additional question just to clarify it completely:
While searching the web, I was not really able to find clear documentation on the grad_input and grad_output parameters of the hook function that is passed to register_backward_hook, beyond what is stated here. This is how I understand it, based on your little code example above:

In the backward pass, something like this is computed:

Let x be the randomly initialized tensor and u = ReLU(x)

Then we would have something like:

dL/dx = du/dx * dL/du (by the chain rule)

where L is the “loss” (i.e. the sum of the elements of u)

In this case du/dx should represent grad_input (i.e. the gradient of the layer output w.r.t. the layer input) and dL/du should represent grad_output (i.e. the gradient of the loss w.r.t. the layer output).

Is my interpretation correct?

(PS: Sorry for the math formatting, LaTeX is not supported. I hope it is still clear.)
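
In case it helps to make my question concrete, here is the small snippet I would use to check this empirically, just printing what the hook receives without modifying anything:

import torch
import torch.nn as nn

relu = nn.ReLU()

def inspect(mod, grad_in, grad_out):
    # Just print what the hook receives; returning None leaves
    # the gradients untouched.
    print("grad_input: ", grad_in)
    print("grad_output:", grad_out)

relu.register_backward_hook(inspect)

x = torch.randn(5, requires_grad=True)
relu(x).sum().backward()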