I am reading this snippet of code and I don't quite get the backward static function. Specifically, what is grad_output, and why do we copy grad_output? I can guess that grad_input is the value that ends up stored in the .grad field of variables that have requires_grad=True, but how is it related to grad_output?
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input so backward() can see where it was negative
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Clone so grad_output itself is not modified in place
        grad_input = grad_output.clone()
        # Zero the gradient wherever the input was negative
        grad_input[input < 0] = 0
        return grad_input
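For context, here is a minimal sketch of how I understand the pieces fit together (the tensor names x, y, and loss are just my own illustration, not from the original snippet); please correct me if this is wrong:

x = torch.randn(5, requires_grad=True)
y = MyReLU.apply(x)     # forward(ctx, input) runs here
loss = y.sum()
loss.backward()         # autograd calls backward(ctx, grad_output)

# grad_output passed to backward() is d(loss)/d(y), here a tensor of ones;
# the grad_input returned by backward() is what lands in x.grad
print(x.grad)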