Building a custom operator using two PyTorch ops

I have the following code in my nn.Module.

x = torch.cdist(a, b)
y = torch.softmax(x, dim=-1)

Both of these are differentiable, but because of the size of x, keeping x and y around needs a lot of GPU memory and causes an OOM during backprop. Since I'm not sure checkpointing can help here (I get NaNs with DDP) and x/y are very sparse (many near-zero values), I was thinking of a custom operator along the following lines, but I cannot find any docs on how to call the backward of existing PyTorch ops from inside a custom operator. Can anyone point me to some examples of how to complete this?

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        x = torch.cdist(a, b)
        # sparse() stands for some yet-to-be-written sparsification of x
        ctx.save_for_backward(a, b, sparse(x))
        return torch.softmax(x, dim=-1)

    @staticmethod
    def backward(ctx, g):
        ??

Have you looked into saved tensor hooks? Saved tensor hooks let you interpose custom behavior when a tensor is "packed" to be saved for backward and when it is "unpacked" to be used during the backward pass. Since GPU memory is your issue, you could consider the context manager torch.autograd.graph.save_on_cpu (which uses saved tensor hooks in its implementation). Inside that context, every tensor saved for backward is moved to the CPU and copied back to its original device when it is needed during the backward pass.

https://pytorch.org/docs/stable/autograd.html?highlight=saved%20tensor%20hooks#torch.autograd.graph.saved_tensors_hook
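
For illustration, here is a minimal sketch of what that could look like for your two ops. The shapes are made up, and it assumes a PyTorch version recent enough to have torch.autograd.graph.save_on_cpu (1.10+):

import torch

a = torch.randn(4096, 64, device="cuda", requires_grad=True)
b = torch.randn(8192, 64, device="cuda", requires_grad=True)

# Tensors saved for backward inside this context are offloaded to (pinned)
# CPU memory and copied back to the GPU only when backward needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    x = torch.cdist(a, b)
    y = torch.softmax(x, dim=-1)

y.sum().backward()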

To directly answer your question though, I believe some of the backwards are available in the torch.ops.aten namespace, e.g. torch.ops.aten._softmax_backward_data and torch.ops.aten._cdist_backward. The full list is in the PyTorch source at aten/src/ATen/native/native_functions.yaml. Note that this is NOT intended to be a public API and is subject to change.
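
To make that concrete, here is a rough sketch of one way the Function could be completed. It deviates from your plan in a few places: it recomputes the distance matrix in backward instead of saving a sparsified copy, it uses the closed-form softmax gradient rather than torch.ops.aten._softmax_backward_data (whose signature has changed between releases), and it assumes the private _cdist_backward schema (grad, x1, x2, p, cdist) from native_functions.yaml, which may also change:

import torch

class CdistSoftmax(torch.autograd.Function):
    """cdist followed by softmax, keeping only the softmax output alive."""

    @staticmethod
    def forward(ctx, a, b):
        x = torch.cdist(a, b)           # p = 2 by default
        y = torch.softmax(x, dim=-1)
        ctx.save_for_backward(a, b, y)  # x is recomputed in backward
        return y

    @staticmethod
    def backward(ctx, g):
        a, b, y = ctx.saved_tensors
        # Closed-form softmax backward: dL/dx = y * (g - sum_j(g_j * y_j))
        gx = y * (g - (g * y).sum(dim=-1, keepdim=True))
        # Recompute the distance matrix just for the cdist backward call.
        x = torch.cdist(a, b)
        p = 2.0
        # Assumed (private) schema: _cdist_backward(grad, x1, x2, p, cdist)
        ga = torch.ops.aten._cdist_backward(gx.contiguous(), a, b, p, x)
        gb = torch.ops.aten._cdist_backward(
            gx.transpose(-1, -2).contiguous(), b, a, p,
            x.transpose(-1, -2).contiguous())
        return ga, gb

# usage: y = CdistSoftmax.apply(a, b)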

Actually, the first option looks better and more sustainable than the second. I will try it out and report back here for future reference, thanks!


I gave the first option (save_on_cpu) a shot. It works, but with a huge throughput penalty from the CPU-GPU tensor copies over PCIe. I will look into writing my own saved tensor hooks instead.
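
For future reference, the generic API behind save_on_cpu is torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook). One idea I plan to try is to offload only the large saved tensors to CPU, so the small ones never touch PCIe. A rough sketch of that idea (the size threshold and shapes are arbitrary, purely illustrative):

import torch

# Offload only large saved tensors to pinned CPU memory; small ones such as
# the raw inputs stay on the GPU, avoiding the PCIe cost that a blanket
# save_on_cpu pays for every saved tensor.
def pack(t):
    if t.numel() < 2**24:          # arbitrary threshold, for illustration
        return t
    cpu = torch.empty(t.size(), dtype=t.dtype, pin_memory=True)
    cpu.copy_(t, non_blocking=True)
    return (t.device, cpu)

def unpack(packed):
    if isinstance(packed, torch.Tensor):
        return packed
    device, cpu = packed
    return cpu.to(device, non_blocking=True)

a = torch.randn(4096, 64, device="cuda", requires_grad=True)
b = torch.randn(8192, 64, device="cuda", requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.softmax(torch.cdist(a, b), dim=-1)

y.sum().backward()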