Hi there, hope all of you are fine.
I am working on VQGAN+CLIP, and there they are doing this operation:

```python
class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (input,) = ctx.saved_tensors
        return (
            grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
            None,
            None,
        )
```

I am trying to understand this, but I cannot find any resources explaining it.
I would be very grateful if anyone could help me with this. Like, what is `ReplaceGrad` doing?

Thanks for any help!

Hi,

A custom Function is simply a way to provide the backward formula directly.
`ReplaceGrad` takes `a` and `b` as input and returns `a`. And in the backward, it passes all the gradient toward `b` and no gradient toward `a`.
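
A minimal runnable sketch of that behavior (the `backward` body here, including the `sum_to_size` call that reduces the gradient to `x_backward`'s shape, is my reconstruction of the class from the question):

```python
import torch

class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward  # forward value comes entirely from x_forward

    @staticmethod
    def backward(ctx, grad_in):
        # No gradient to x_forward; all gradient routed to x_backward,
        # reduced to x_backward's shape in case broadcasting happened.
        return None, grad_in.sum_to_size(ctx.shape)

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([10.0, 20.0], requires_grad=True)
out = ReplaceGrad.apply(a, b)
out.sum().backward()
print(out.detach())  # tensor([1., 2.]) — the values of a
print(a.grad)        # None — no gradient flows to a
print(b.grad)        # tensor([1., 1.]) — all gradient flows to b
```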

1 Like

Hi @albanD , hope you are fine.
Thanks for helping me out.
Sorry, I am not able to fully understand this. Can you share some resources where I can understand it…?

Thanks

The documentation has an example for a simple exponential function here
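
From memory, the documented pattern looks roughly like this (an `Exp` function that saves its own output, since d/dx exp(x) = exp(x)):

```python
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        # Save the *output* for backward, since the derivative of exp is exp
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result

x = torch.tensor(1.0, requires_grad=True)
y = Exp.apply(x)
y.backward()
# x.grad equals exp(1), the same value as y
```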

Another example from the forums is here; although it’s for a custom double backward (so a 2nd derivative), it might help illustrate what a custom `torch.autograd.Function` does and what you need to code for it to work.

1 Like

Thanks, will look into it.

Sorry, I am not able to get the concept.
I would be grateful if you can explain this piece of code to me:

```python
class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (input,) = ctx.saved_tensors
        return (
            grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
            None,
            None,
        )
```

Like how is it trying to change `grad_in`, and why do we have the `None, None`?

Thanks

Ok, so after having a look on the forums, there’s a link to how `torch.clamp` is implemented, located here. I also checked the gradient of that custom function, and I’m pretty sure it’s wrong!

With regards to what `torch.autograd.Function` does, it’s a way (as @albanD said) to manually tell PyTorch what the derivative of a function should be (as opposed to getting the derivative automatically from Automatic Differentiation).
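
As a hypothetical illustration of that point (the `Cube` function and its test values are mine, not from the thread): the manually-specified derivative should agree with what automatic differentiation computes on its own.

```python
import torch

# Cube with a manually-specified derivative: d/dx x^3 = 3x^2
class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2

# Manual derivative via the custom Function
x1 = torch.tensor(2.0, requires_grad=True)
Cube.apply(x1).backward()

# Automatic differentiation of the same expression
x2 = torch.tensor(2.0, requires_grad=True)
(x2 ** 3).backward()

print(x1.grad, x2.grad)  # tensor(12.) tensor(12.)
```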

A corrected version of the function is below:

```python
class ClampWithGradThatWorks(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_out):
        input, = ctx.saved_tensors
        grad_in = grad_out * (input.ge(ctx.min) * input.le(ctx.max))
        return grad_in, None, None
```

I’ve renamed some variables to more appropriate names (like `grad_in` to `grad_out`).

The `forward` method of a function defines the actual output the function returns; the `backward` method is where the gradients are calculated.

The `grad_out` term is the gradient of the loss with respect to the output of the function, hence why I named it `grad_out`. In order to get the gradient for the input terms (or any other terms within the function), we need to calculate the derivative of the output with respect to those terms. So, in this case we need the gradient of the output with respect to

1. input
2. min
3. max

As `min` and `max` are fixed values independent of the input, their gradients are zero (which can be represented by `None`). So, that explains the `None, None` terms in the output. But what about the `grad_in = grad_out*(input.ge(ctx.min) * input.le(ctx.max))` term?

Well, if you plot the clamp function out, it’s effectively the identity function between `min` and `max`, constant at `min` for any input less than `min`, and constant at `max` for any input greater than `max`. So, looking at the gradient, it’s 0 for an `input` less than `min` and 0 for an `input` greater than `max`.

For any input in between `min` and `max` it’s 1. So, we need to define the gradient of this ‘in-between’ section as 1. This can be done by multiplying `input.ge(ctx.min)` with `input.le(ctx.max)`. This also defines the gradient of the previous 2 sections (gradient = 0 for `input < min` and `input > max`).
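
You can see exactly this gradient pattern from PyTorch’s built-in `torch.clamp` (the sample values here are mine, chosen so one element is below `min`, one inside the range, and one above `max`):

```python
import torch

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = torch.clamp(x, min=-1.0, max=1.0)  # -> [-1.0, 0.5, 1.0]
y.sum().backward()

# Gradient is 1 only where min <= x <= max, and 0 where x was clamped
print(x.grad)  # tensor([0., 1., 0.])
```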

We’re nearly done. Once we’ve defined the gradient of the output with respect to the input, we need to connect it with the gradient of the loss with respect to the output (`grad_out`) in order to get the gradient of the loss with respect to the input, which is `grad_in`. This can simply be done by using the chain rule:

```python
grad_in = grad_out * (input.ge(ctx.min) * input.le(ctx.max))
```
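
Putting it all together, here is a self-contained sketch that checks the corrected function’s gradient against the one PyTorch derives automatically for `torch.clamp` (test values are mine; they avoid the exact boundary points, where subgradient conventions could differ):

```python
import torch

class ClampWithGradThatWorks(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_out):
        (input,) = ctx.saved_tensors
        # Chain rule: gradient is grad_out where min <= input <= max, else 0;
        # None, None for the constant min/max arguments.
        grad_in = grad_out * (input.ge(ctx.min) * input.le(ctx.max))
        return grad_in, None, None

# Gradient via the custom Function
x1 = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
ClampWithGradThatWorks.apply(x1, -1.0, 1.0).sum().backward()

# Gradient via built-in clamp and automatic differentiation
x2 = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
torch.clamp(x2, -1.0, 1.0).sum().backward()

print(x1.grad)                        # tensor([0., 1., 0.])
print(torch.equal(x1.grad, x2.grad))  # True
```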

Hopefully, this explains what `torch.autograd.Function` is doing!

1 Like

Thanks a loooot.
It is so great of you, to explain in such detail.
I will try out some examples, to get the most of it.

Thanks again…

1 Like

`ctx` (in `forward` and `backward`) is just acting like `self` here, right? Or does it do something special?

Kinda, except it has some methods of its own (more info here)
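
A small hypothetical example of the distinction (the `Square` function is mine, for illustration): `ctx` is a fresh context object per call, not an instance, and tensors needed in `backward` should go through `save_for_backward` rather than plain attributes.

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ctx is a per-call context object, not self: both methods are static
        # and the Function holds no instance state. Tensors needed in backward
        # go through save_for_backward (so autograd can track them safely);
        # non-tensor values can be stashed as plain attributes, e.g. ctx.n = 2.
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x  # d/dx x^2 = 2x

x = torch.tensor(3.0, requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor(6.)
```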

1 Like

Hi,

```python
@staticmethod
def backward(ctx, grad_out):
    ...
```

`grad_out` is passed in by the autograd engine and is the gradient of the loss with respect to the output of your function, hence the name `grad_out`.