Simple question about the backward() behavior of a custom extension

Suppose that I have a CUDA extension with which I can define a custom autograd function like the following:

import torch
import my_cuda_extension

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.y = y  # stash y for the backward pass
        return my_cuda_extension.forward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = my_cuda_extension.backward(grad_output, ctx.y)
        return grad_x, torch.zeros_like(ctx.y)

Now, suppose that my CUDA extension only accepts the NHWC data format. Thus, when defining my custom layer, I have to convert the NCHW inputs to NHWC, then convert the result back to NCHW:

class MyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # convert the inputs to NHWC
        result = MyFunction.apply(x.permute(0,2,3,1).contiguous(),
                                  y.permute(0,2,3,1).contiguous())
        # convert the result back to NCHW
        return result.permute(0,3,1,2)

The forward pass should clearly work, but what about the backward pass? Will I get the correct gradients? I don't see how the backward pass through the permute operations is handled in the above definition of MyLayer.

Thank you in advance for your answer!

Hi,

All the code outside the custom Function is handled by autograd, so a gradient will be computed for every op you do there, including the permute calls. Everything will work fine here!
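
If you want to double-check this end to end, torch.autograd.gradcheck can compare your analytical gradients (the custom backward plus autograd through the permutes) against finite differences. A minimal sketch, assuming my_cuda_extension also accepts float64 (double) tensors on a CUDA device, which gradcheck needs for reliable numerical estimates:

import torch
from torch.autograd import gradcheck

layer = MyLayer()
# Small double-precision inputs; only x requires grad here, since
# MyFunction returns a zero gradient for y anyway.
x = torch.randn(2, 4, 8, 8, dtype=torch.double, device="cuda", requires_grad=True)
y = torch.randn(2, 4, 8, 8, dtype=torch.double, device="cuda")

# Compares the backward result (including the permutes handled by autograd)
# against a finite-difference estimate of the gradient w.r.t. x.
print(gradcheck(layer, (x, y), eps=1e-6, atol=1e-4))

Whether it passes of course depends on my_cuda_extension.backward returning the true gradient.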

Note that inside the custom Function you should not save inputs directly on the context; use ctx.save_for_backward(y) in forward and y, = ctx.saved_tensors in backward instead, to avoid a memory leak 🙂
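
For concreteness, here is a sketch of the same Function rewritten that way (keeping your hypothetical extension calls and the zero gradient for y):

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(y)  # goes through the saved-tensors machinery
        return my_cuda_extension.forward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        grad_x = my_cuda_extension.backward(grad_output, y)
        return grad_x, torch.zeros_like(y)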


Sorry for the late reply. Thanks a lot for your answer! That totally makes sense.

Regarding the use of ctx.save_for_backward, according to this comment, a memory leak no longer seems to occur as of PyTorch 0.4. Is that true? I'll use ctx.save_for_backward anyway, since it seems to be the safer option. Thanks for the tip!

Regarding the use of ctx.save_for_backward, according to this comment, a memory leak no longer seems to occur as of PyTorch 0.4. Is that true?

That is for intermediary results, not inputs/outputs. You still have to use save_for_backward() for those.
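
To illustrate the distinction with a toy Function (unrelated to the CUDA extension above): an intermediary computed inside forward may be kept directly on ctx, while the inputs go through save_for_backward.

import torch

class ReluProduct(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        prod = x * w
        ctx.mask = prod > 0             # intermediary result: fine to keep on ctx
        ctx.save_for_backward(x, w)     # inputs: must use save_for_backward
        return prod.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        grad = grad_output * ctx.mask   # gradient of relu is the positive mask
        return grad * w, grad * x       # d/dx and d/dw of relu(x * w)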

Thanks for the confirmation!