Simple question about the backward() behavior of a custom extension

Suppose that I have a CUDA extension with which I can define a custom autograd function like the following:

```python
import torch
import my_cuda_extension

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.y = y
        return my_cuda_extension.forward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        result = my_cuda_extension.backward(grad_output, ctx.y)
        return result, torch.zeros_like(ctx.y)
```

Now, suppose that my CUDA extension only accepts the NHWC data format. So when defining my custom layer, I have to convert the NCHW inputs to NHWC, then convert the result back to NCHW:

```python
class MyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # convert the inputs to NHWC
        result = MyFunction.apply(x.permute(0, 2, 3, 1).contiguous(),
                                  y.permute(0, 2, 3, 1).contiguous())
        # convert the result back to NCHW
        return result.permute(0, 3, 1, 2)
```
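To make the layout conversion concrete, here is a small sketch (the tensor sizes are arbitrary) of what the two permutations in MyLayer do to the shape and the memory layout:

```python
import torch

# An NCHW tensor: batch N=2, channels C=3, height H=4, width W=5
x_nchw = torch.randn(2, 3, 4, 5)

# permute(0, 2, 3, 1) moves the channel axis last: (N, C, H, W) -> (N, H, W, C)
x_nhwc = x_nchw.permute(0, 2, 3, 1)
print(x_nhwc.shape)            # torch.Size([2, 4, 5, 3])
print(x_nhwc.is_contiguous())  # False: permute only changes the strides

# .contiguous() copies the data so it is really laid out as NHWC in memory,
# which a kernel that indexes raw NHWC memory would rely on
x_nhwc = x_nhwc.contiguous()
print(x_nhwc.is_contiguous())  # True

# permute(0, 3, 1, 2) is the inverse permutation: (N, H, W, C) -> (N, C, H, W)
back = x_nhwc.permute(0, 3, 1, 2)
print(back.shape)                 # torch.Size([2, 3, 4, 5])
print(torch.equal(back, x_nchw))  # True
```

Both permute and contiguous are ordinary differentiable tensor ops, so autograd records them like any other op.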

The forward pass should clearly work, but what about the backward pass? Will I get the correct gradients? I don't see how the backward pass is involved in the above definition of MyLayer.

Thank you in advance for your answer!

Hi,

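All the code outside the custom Function is differentiated by autograd as usual, so a gradient is computed for each op you do there (the permute and contiguous calls included). Your custom backward only handles the MyFunction step, and autograd chains it with the gradients of the surrounding ops. So everything will work fine here!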

Note that in the custom Function you should not store inputs directly on the context (`ctx.y = y`). Instead, use `ctx.save_for_backward(y)` in forward and `y, = ctx.saved_tensors` in backward, to avoid memory leaks :)
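Here is a minimal runnable sketch of the whole setup. It assumes hypothetical pure-PyTorch stand-ins, `fake_nhwc_forward` and `fake_nhwc_backward`, in place of my_cuda_extension (the real kernels aren't available here). The stand-ins treat the last axis as channels, so they only make sense on NHWC input. MyFunction saves y with `ctx.save_for_backward`, and `torch.autograd.gradcheck` compares the resulting gradients (custom backward plus autograd's handling of permute/contiguous) against finite differences:

```python
import torch

# Hypothetical stand-ins for my_cuda_extension (assumption: not the real kernels).
# They treat the LAST axis as channels, i.e. they expect NHWC input.
def fake_nhwc_forward(x, y):
    return x * y.cumsum(dim=-1)

def fake_nhwc_backward(grad_output, y):
    # out = x * cumsum(y) is elementwise in x, so d(out)/dx = cumsum(y)
    return grad_output * y.cumsum(dim=-1)

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(y)  # instead of ctx.y = y
        return fake_nhwc_forward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        # grad_output arrives through permute's backward and may be
        # non-contiguous; a real NHWC kernel would likely need this copy
        grad_output = grad_output.contiguous()
        # None = no gradient for y in this sketch
        return fake_nhwc_backward(grad_output, y), None

class MyLayer(torch.nn.Module):
    def forward(self, x, y):
        result = MyFunction.apply(x.permute(0, 2, 3, 1).contiguous(),
                                  y.permute(0, 2, 3, 1).contiguous())
        return result.permute(0, 3, 1, 2)

layer = MyLayer()
x = torch.randn(2, 3, 4, 5, dtype=torch.double, requires_grad=True)
y = torch.randn(2, 3, 4, 5, dtype=torch.double)

# gradcheck raises if analytical and numerical gradients disagree
ok = torch.autograd.gradcheck(lambda inp: layer(inp, y), (x,))
print(ok)  # True
```

Returning None rather than `torch.zeros_like(y)` for the second input is the usual way to signal "no gradient" and avoids allocating a zero tensor.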