How can I modify a ReLU layer's backward?

I want to modify the backward pass of ReLU so that I simply pass through the gradients coming from the top, rather than zeroing out the ones where the unit is off.

I looked at this thread and couldn't get much out of it. Looking here, while that approach works for making gradients zero, I am not sure what the components of gradient_input are and which one I should modify to get a pass-through.

You can use the backward hooks on the relu Module:

import torch
import torch.nn as nn

inp = torch.randn(5, requires_grad=True)
print(inp) # [ 1.7485,  0.5180, -1.2110, -1.3865, -0.6293]

relu = nn.ReLU()

# Without hook:
out = relu(inp)
out.sum().backward()
print(inp.grad) # [1., 1., 0., 0., 0.]

inp.grad.zero_()
# With hook
def grad_hook(mod, grad_in, grad_out):
    # Ignore the computed grad_input
    # Make it look like this is an identity
    return grad_out
relu.register_backward_hook(grad_hook)

out = relu(inp)
out.sum().backward()
print(inp.grad) # [1., 1., 1., 1., 1.]
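
As a side note, more recent PyTorch releases deprecate register_backward_hook in favour of register_full_backward_hook, and the same trick works there as well. A minimal self-contained sketch, assuming a reasonably new PyTorch version:

import torch
import torch.nn as nn

inp = torch.randn(5, requires_grad=True)
relu = nn.ReLU()

def grad_hook(mod, grad_in, grad_out):
    # Return grad_output as the new grad_input, i.e. an identity backward
    return grad_out

# register_full_backward_hook is the non-deprecated replacement
relu.register_full_backward_hook(grad_hook)

relu(inp).sum().backward()
print(inp.grad)  # all ones, regardless of the sign of inp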

Thanks for the reply. From this post, however, it seems that what one has to return is a modification of grad_input, whereas in the snippet in your post we seem to be returning a modification of grad_out.

That depends on what you want to do.
If you want your ReLU's backward to act as if it were an identity, then the gradient of the output w.r.t. the input is 1, so what you want your backward to compute is grad_output * 1, ignoring the grad_input computed by the ReLU's backward, since you are replacing it.
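
If you prefer not to use hooks at all, the same pass-through behaviour can be written as a custom autograd Function whose forward is an ordinary ReLU and whose backward just returns the upstream gradient unchanged. A minimal sketch (the class name PassThroughReLU is mine, not something from this thread):

import torch

class PassThroughReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Ordinary ReLU in the forward pass
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity backward: pass the upstream gradient straight through
        # instead of zeroing it where x < 0
        return grad_output

x = torch.randn(5, requires_grad=True)
PassThroughReLU.apply(x).sum().backward()
print(x.grad)  # all ones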

Yeah, I see your point: for a pointwise activation, grad_input and grad_output have the same size, so it's straightforward to substitute grad_output for grad_input directly.
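
A quick way to convince yourself of that is a hook that only prints shapes (a throwaway probe, just for illustration):

import torch
import torch.nn as nn

relu = nn.ReLU()
x = torch.randn(3, 4, requires_grad=True)

def shape_probe(mod, grad_in, grad_out):
    # Both are tuples; for a pointwise op each holds a single tensor
    # with the same shape as the input/output
    print([g.shape for g in grad_in], [g.shape for g in grad_out])

relu.register_backward_hook(shape_probe)
relu(x).sum().backward()
# prints [torch.Size([3, 4])] [torch.Size([3, 4])]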

One extra question, please: from other posts it seems that modifying gradients is not so straightforward for other layers. Is that still the case?
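
For what it's worth, one commonly used alternative (my own sketch, not an answer from this thread): for layers with several inputs and parameters, grad_input has several components, so a tensor-level hook via Tensor.register_hook on the specific tensor whose gradient you care about is often simpler than a module-level hook:

import torch
import torch.nn as nn

lin = nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)

out = lin(x)
# The hook runs on the gradient flowing into `out`, before it is
# propagated back to x, lin.weight and lin.bias
out.register_hook(lambda g: 2 * g)  # e.g. scale the upstream gradient

out.sum().backward()
print(x.grad)  # twice what it would be without the hook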