Customizing torch.autograd.Function

Hi there, hope all of you are fine.
I am working on VQGAN+CLIP, and there they are doing this operation:

import torch


class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (input,) = ctx.saved_tensors
        return (
            grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
            None,
            None,
        )


replace_grad = ReplaceGrad.apply
clamp_with_grad = ClampWithGrad.apply

I am trying to understand this, but I cannot find any resources explaining it.
I would be very grateful if anyone could help me with this. For example, what is ReplaceGrad doing?

Thanks for any help!

Hi,

A custom Function is simply a way to provide the backward formula directly.
ReplaceGrad takes a and b as inputs and returns a. In the backward pass, it sends all of the gradient to b and no gradient to a.
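To make that concrete, here is a small sketch that runs ReplaceGrad from the question on two toy tensors. The tensors a and b and their values are made up for illustration.

```python
import torch


class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Remember the shape of the tensor that should receive the gradient.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        # No gradient for x_forward; the whole gradient goes to x_backward.
        return None, grad_in.sum_to_size(ctx.shape)


# Toy inputs (illustrative values only).
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = torch.tensor([10.0, 20.0, 30.0], requires_grad=True)

out = ReplaceGrad.apply(a, b)
out.sum().backward()

print(out)     # the forward values come from a
print(a.grad)  # None: a receives no gradient
print(b.grad)  # tensor([1., 1., 1.]): b receives all of it
```

The sum_to_size call reduces the incoming gradient to b's shape, which matters when b is smaller than a and would otherwise be broadcast.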

Hi @albanD, hope you are fine.
Thanks for helping me out.
Sorry, I am not able to fully understand this. Can you share some resources where I can learn more about it?

Thanks

The documentation has an example for a simple exponential function here

Another example from the forums is here. Although it’s for a custom double backward (so the 2nd derivative), it might help illustrate what a custom torch.autograd.Function does and what you need to code for it to work.

Thanks, will look into it.

Sorry, I am still not able to get the concept.
I would be grateful if you could explain this piece of code to me:

class ClampWithGrad(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input, min, max):
    ctx.min = min
    ctx.max = max
    ctx.save_for_backward(input)
    return input.clamp(min, max)

  @staticmethod
  def backward(ctx, grad_in):
    (input,) = ctx.saved_tensors
    return (
        grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
        None, None,
    )

For example, how is it changing grad_in, and why do we have this None, None?

Thanks

Ok, so after having a look on the forums, there’s a link to how torch.clamp is implemented, and it’s located here. I checked the gradient of that custom function, and I’m pretty sure it doesn’t match the actual derivative of clamp!

With regard to what torch.autograd.Function does, it’s a way (as @albanD said) to manually tell PyTorch what the derivative of a function should be (as opposed to having PyTorch work out the derivative automatically via automatic differentiation).
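As a minimal sketch of that idea, along the lines of the exponential example mentioned earlier in the thread, here is a hand-written derivative for exp, checked numerically with torch.autograd.gradcheck. The class name Exp and the random input are just for illustration.

```python
import torch


class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        # d/dx exp(x) = exp(x), so saving the output is enough for backward.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        (result,) = ctx.saved_tensors
        # Chain rule: dLoss/dx = dLoss/dout * dout/dx.
        return grad_out * result


# gradcheck compares the hand-written backward against finite differences.
# It expects double-precision inputs that require grad.
x = torch.randn(4, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(Exp.apply, (x,))
print(ok)
```

If the formula in backward were wrong, gradcheck would raise an error instead of returning True, which makes it a handy sanity check for any custom Function.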

A version whose gradient matches the derivative of clamp is below:

class ClampWithGradThatWorks(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input, min, max):
    ctx.min = min
    ctx.max = max
    ctx.save_for_backward(input)
    return input.clamp(min, max)

  @staticmethod
  def backward(ctx, grad_out):
    input, = ctx.saved_tensors
    grad_in = grad_out*(input.ge(ctx.min) * input.le(ctx.max))
    return grad_in, None, None

I’ve renamed some variables to more appropriate names (like grad_in to grad_out).
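To see how the two versions actually differ, this sketch runs both on the same toy input, once with a positive and once with a negative upstream gradient. The helper input_grad and the test values are made up for illustration.

```python
import torch


class ClampWithGrad(torch.autograd.Function):
    """The version from the question."""

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (input,) = ctx.saved_tensors
        # Zero inside [min, max], signed distance to the range outside it.
        overshoot = input - input.clamp(ctx.min, ctx.max)
        return grad_in * (grad_in * overshoot >= 0), None, None


class ClampWithGradThatWorks(torch.autograd.Function):
    """The version that follows the derivative of clamp."""

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_out):
        (input,) = ctx.saved_tensors
        return grad_out * (input.ge(ctx.min) * input.le(ctx.max)), None, None


def input_grad(fn, sign):
    # Hypothetical helper: gradient of sign * sum(clamp(x, 0, 1)) w.r.t. x.
    x = torch.tensor([-2.0, 0.5, 2.0], requires_grad=True)
    (sign * fn(x, 0.0, 1.0)).sum().backward()
    return x.grad


for sign in (1.0, -1.0):
    print(sign, input_grad(ClampWithGrad.apply, sign))
    print(sign, input_grad(ClampWithGradThatWorks.apply, sign))
```

The second version always zeroes the gradient outside [0, 1]. The version from the question still lets the gradient through for an out-of-range input when a gradient descent step would move that input back toward the range (2.0 with a positive gradient, -2.0 with a negative one). That may well be intentional in the original code, but that is my reading of the formula, not something stated in the source.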

The forward method computes the actual output that the function returns; the backward method is where the gradients are calculated.

The grad_out term is the gradient of the loss with respect to the output of the function, which is why I named it grad_out. To get the gradient for the input terms (or any other arguments of the function), we need the derivative of the output with respect to each of them. So, in this case we need the gradient of the output with respect to

  1. input
  2. min
  3. max

As min and max are fixed constants, we don’t need a gradient for them, which is expressed by returning None. backward has to return one value per argument of forward, so that explains the None, None term in the output. But what about the grad_in = grad_out*(input.ge(ctx.min) * input.le(ctx.max)) term?
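Before getting to that term, a quick illustration of the None entries. This is a made-up Function (Scale is not from the original code) that takes one tensor and one plain Python float, so its backward returns one gradient and one None.

```python
import torch


class Scale(torch.autograd.Function):
    """Hypothetical example: multiply a tensor by a plain Python float."""

    @staticmethod
    def forward(ctx, x, factor):
        # Non-tensor values can simply be stored as attributes on ctx.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward argument, in the same order: (x, factor).
        # factor is a constant, so its slot is None.
        return grad_out * ctx.factor, None


x = torch.tensor([1.0, 2.0], requires_grad=True)
Scale.apply(x, 3.0).sum().backward()
print(x.grad)  # tensor([3., 3.])
```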

Well, if you plot the clamp function, it’s constant at min for any input less than min, linear (it just returns the input) between min and max, and constant at max for any input greater than max. So, looking at the gradient, it’s 0 for an input less than min and 0 for an input greater than max.

For any input in between min and max it’s 1. So, we need to define the gradient of this ‘in-between’ section as 1. This can be done by multiplying input.ge(ctx.min) with input.le(ctx.max), which gives a mask that is True (1) inside the range and False (0) outside it. That also takes care of the other 2 sections (gradient = 0 for input < min and input > max).
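If you want to convince yourself of those three regions without any autograd at all, here is a rough finite-difference check in plain Python. The helper names and the sample points are just for illustration.

```python
def clamp(x, lo, hi):
    # Scalar clamp: lo below the range, x inside it, hi above it.
    return max(lo, min(x, hi))


def numeric_derivative(x, lo, hi, eps=1e-6):
    # Central finite difference approximation of d clamp / d x.
    return (clamp(x + eps, lo, hi) - clamp(x - eps, lo, hi)) / (2 * eps)


for x in (-2.0, 0.5, 2.0):
    print(x, round(numeric_derivative(x, 0.0, 1.0), 4))
```

With min = 0 and max = 1 this gives a slope of about 0 at -2.0, about 1 at 0.5, and about 0 at 2.0, which is exactly what the mask encodes.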

We’re nearly done. Once we’ve defined the gradient of the output with respect to the input, we need to combine it with the gradient of the loss with respect to the output (grad_out) to get the gradient of the loss with respect to the input, which is grad_in. This is just the chain rule, which here gives,

    grad_in = grad_out*(input.ge(ctx.min) * input.le(ctx.max))
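As a sanity check on that formula, this sketch compares the custom Function against the gradient autograd computes for the built-in torch.clamp. The weights w are arbitrary and only there so that grad_out is not all ones; the inputs are chosen away from the boundaries.

```python
import torch


class ClampWithGradThatWorks(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_out):
        (input,) = ctx.saved_tensors
        grad_in = grad_out * (input.ge(ctx.min) * input.le(ctx.max))
        return grad_in, None, None


w = torch.tensor([2.0, 3.0, 4.0])  # arbitrary weights, so grad_out == w

x_custom = torch.tensor([-2.0, 0.5, 2.0], requires_grad=True)
(ClampWithGradThatWorks.apply(x_custom, 0.0, 1.0) * w).sum().backward()

x_builtin = torch.tensor([-2.0, 0.5, 2.0], requires_grad=True)
(torch.clamp(x_builtin, 0.0, 1.0) * w).sum().backward()

print(x_custom.grad)   # tensor([0., 3., 0.])
print(x_builtin.grad)  # tensor([0., 3., 0.])
```

Only the middle element is inside [0, 1], so only its weight (3) survives the mask.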

Hopefully, this explains what torch.autograd.Function is doing! :)

Thanks a loooot.
It is so great of you to explain in such detail.
I will try out some examples to get the most out of it.

Thanks again…

Is ctx (in forward and backward) just acting like self here, or does it do something special?

Kinda, except it has some methods of its own, such as save_for_backward (more info here).
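For a rough idea of what ctx offers beyond being a place to stash attributes, here is a small made-up Function (Mul is not from the original code) that uses save_for_backward, saved_tensors and needs_input_grad.

```python
import torch


class Mul(torch.autograd.Function):
    """Hypothetical example: elementwise product of two tensors."""

    @staticmethod
    def forward(ctx, a, b):
        # Tensors needed in backward go through save_for_backward
        # rather than being stored as plain attributes.
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # needs_input_grad has one boolean per forward argument, so work
        # can be skipped for inputs that do not require a gradient.
        grad_a = grad_out * b if ctx.needs_input_grad[0] else None
        grad_b = grad_out * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b


a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0])  # does not require grad

Mul.apply(a, b).sum().backward()
print(a.grad)  # tensor([3., 4.])
print(b.grad)  # None
```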

Hi,

  @staticmethod
  def backward(ctx, grad_out):
    input, = ctx.saved_tensors
    grad_in = grad_out*(input.ge(ctx.min) * input.le(ctx.max))
    return grad_in, None, None

Here, is grad_out what is returned by forward?

No, grad_out is passed in by the autograd engine. It is the gradient of the loss with respect to the output of your function, hence the name grad_out.
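One way to see this for yourself is a small pass-through Function that records whatever the engine hands to backward. Spy, seen and the weights w below are made up for illustration.

```python
import torch

seen = []  # collects every grad_out that backward receives


class Spy(torch.autograd.Function):
    """Hypothetical pass-through Function that records grad_out."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        seen.append(grad_out.clone())
        # The output equals the input, so the gradient passes through as is.
        return grad_out


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
w = torch.tensor([10.0, 20.0, 30.0])

out = Spy.apply(x)
loss = (out * w).sum()  # d(loss)/d(out) = w
loss.backward()

print(seen[0])  # tensor([10., 20., 30.])
print(x.grad)   # tensor([10., 20., 30.])
```

What backward receives is w, the derivative of the loss with respect to out, and not the values that forward returned.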
