Ok, so after having a look on the forums there’s a link to how torch.clamp is implemented, located here. I checked the gradient for that custom function and I’m pretty sure it’s wrong! As for what torch.autograd.Function does, it’s a way (as @albanD said) to manually tell PyTorch what the derivative of a function should be, as opposed to letting automatic differentiation derive it for you. This version of the function is below:
import torch

class ClampWithGradThatWorks(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        # stash the bounds and the input tensor for use in backward
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_out):
        input, = ctx.saved_tensors
        # pass the gradient through only where min <= input <= max
        grad_in = grad_out * (input.ge(ctx.min) * input.le(ctx.max))
        # one return value per argument of forward: None for min and max
        return grad_in, None, None
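As a quick sanity check (this usage snippet is mine, not from the linked post), a custom Function is called through its apply method, and the forward pass should match the built-in clamp:

x = torch.randn(5, requires_grad=True)
y = ClampWithGradThatWorks.apply(x, -0.5, 0.5)
# the forward pass behaves exactly like the built-in clamp
assert torch.equal(y, x.clamp(-0.5, 0.5))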
By the way, I’ve renamed some variables to more appropriate names compared to the linked version (like grad_in to grad_out).
The forward method of the function computes the actual output the function returns, while the backward method is where the gradients are calculated.
The grad_out term is the gradient of the loss with respect to the output of the function, hence why I named it grad_out. In order to get the gradient for the input (or for any other argument of the function) we need to calculate the derivative of the output with respect to each of them. So, in this case we need the gradient of the output with respect to:
- input
- min
- max
As min and max are fixed values here (plain numbers rather than tensors that require gradients), there is no gradient to return for them, which is represented by None. That explains the None, None part of the return value. But what about the grad_in = grad_out*(input.ge(ctx.min) * input.le(ctx.max)) term?
Well, if you plot the clamp function, it’s effectively the identity (a linear function with slope 1) between min and max, a constant min for any input less than min, and a constant max for any input greater than max. So, looking at the gradient: it’s 0 for an input less than min, 0 for an input greater than max, and 1 for any input in between min and max. We therefore need to define the gradient of this ‘in-between’ section as 1, which can be done by multiplying input.ge(ctx.min) with input.le(ctx.max). This also takes care of the other two sections (gradient = 0 for input < min and for input > max).
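To make that concrete, here’s a small sketch (the values are my own) with one element below min, one in between, and one above max; the gradient that comes back is 0, 1 and 0 respectively:

x = torch.tensor([-2.0, 0.0, 3.0], requires_grad=True)
y = ClampWithGradThatWorks.apply(x, -1.0, 1.0)   # y = [-1., 0., 1.]
y.sum().backward()                               # grad_out is 1 for every element
print(x.grad)                                    # tensor([0., 1., 0.])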
We’re nearly done. Once we’ve defined the gradient of the output with respect to the input, we need to connect it to the gradient of the loss with respect to the output (grad_out) in order to get the gradient of the loss with respect to the input, which is grad_in. This is done with the chain rule, which here is simply
grad_in = grad_out*(input.ge(ctx.min) * input.le(ctx.max))
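As an illustration of that chain rule (again a sketch of my own, with an arbitrary factor of 3): if the loss scales the clamp output by 3, then grad_out is 3 everywhere, and the input gradient becomes 3 in the ‘in-between’ region and 0 outside it:

x = torch.tensor([-2.0, 0.5, 4.0], requires_grad=True)
y = ClampWithGradThatWorks.apply(x, -1.0, 1.0)
loss = (3 * y).sum()    # dLoss/doutput (grad_out) is 3 for every element
loss.backward()
print(x.grad)           # tensor([0., 3., 0.]) = grad_out * mask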
Hopefully, this explains what torch.autograd.Function is doing!