I want to modify the backward of ReLU so that it simply passes through the gradients coming from the top, rather than zeroing out the ones where the unit is off.
I looked at this thread and couldn’t get much out of it. Looking here, while this works for making the gradients zero, I am not sure what the components of
gradient_input are, or which one I should modify to get a pass-through.
You can use a backward hook on the ReLU Module:
import torch
import torch.nn as nn

inp = torch.randn(5, requires_grad=True)
print(inp)  # e.g. [ 1.7485, 0.5180, -1.2110, -1.3865, -0.6293]
relu = nn.ReLU()

# Without hook:
out = relu(inp)
out.sum().backward()
print(inp.grad)  # [1., 1., 0., 0., 0.]

# With hook:
def grad_hook(mod, grad_in, grad_out):
    # Ignore the computed grad_input and return grad_output
    # unchanged, so the backward acts like an identity.
    return grad_out

relu.register_full_backward_hook(grad_hook)
inp.grad = None  # clear the gradient from the first backward
out = relu(inp)
out.sum().backward()
print(inp.grad)  # [1., 1., 1., 1., 1.]
Thanks for the reply. From this post, however, it seems like what one has to return is a modification of
grad_input, whereas in the snippet in your post we seem to be returning a modification of grad_output.
That depends on what you want to do.
If you want your ReLU’s backward to act as if it were an identity, then the gradient of the output w.r.t. the input is 1, and so what you want your backward to compute is
grad_output * 1, ignoring the grad_input computed by the backward of the ReLU, since you are replacing it.
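The same identity-backward idea can also be written as a custom autograd Function instead of a hook. A minimal sketch (the class name PassThroughReLU is my own, not from this thread):

```python
import torch

class PassThroughReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        # Standard ReLU forward.
        return inp.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity backward: pass the upstream gradient through
        # instead of masking it where the unit was off.
        return grad_output

x = torch.randn(5, requires_grad=True)
y = PassThroughReLU.apply(x)
y.sum().backward()
print(x.grad)  # all ones, regardless of the sign of x
```

This avoids touching the Module at all, at the cost of calling PassThroughReLU.apply explicitly in the forward pass.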
Yeah, I see your point: for a pointwise activation,
grad_output would be the same size as grad_input, so it’s straightforward to use
grad_output directly as the substitute.
An extra question, please: from other posts it seems that modifying gradients is not so straightforward for other layers. Is that still the case?
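For context on why it is less straightforward, here is a sketch (my own example, not from this thread) that just inspects what a full backward hook sees on nn.Linear: grad_input has the shape of the layer’s input while grad_output has the shape of its output, so the return-grad_out trick from the ReLU example does not carry over:

```python
import torch
import torch.nn as nn

lin = nn.Linear(3, 2)
shapes = {}

def inspect_hook(mod, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module's forward inputs
    # grad_output: gradients w.r.t. the module's outputs
    shapes["in"] = [tuple(g.shape) for g in grad_input if g is not None]
    shapes["out"] = [tuple(g.shape) for g in grad_output]

lin.register_full_backward_hook(inspect_hook)
x = torch.randn(4, 3, requires_grad=True)
lin(x).sum().backward()
print(shapes["in"], shapes["out"])  # [(4, 3)] [(4, 2)]
```

Because the shapes differ, replacing grad_input for such a layer requires computing a tensor of the input’s shape yourself rather than reusing grad_output.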