Building a custom operator from two PyTorch ops

I have the following code in my nn.Module.

x = torch.cdist(a, b)
y = torch.softmax(x, dim=-1)

These two ops are differentiable, but because x is so large, x and y need a lot of GPU memory, which causes an OOM during backprop. I'm not sure whether checkpointing (torch.utils.checkpoint) can help here (I get NaNs with it under DDP), and since x and y are very sparse (many values are near zero), I was thinking of writing a custom operator along the following lines. However, I cannot find any documentation on how to call the backward of existing PyTorch ops inside a custom operator. Can anyone point me to some examples of how to complete this?

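class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        x = torch.cdist(a, b)
        # sparse() is a placeholder for some compressed representation of x
        ctx.save_for_backward(a, b, sparse(x))
        return torch.softmax(x, dim=-1)

    @staticmethod
    def backward(ctx, g):
        a, b, x_sparse = ctx.saved_tensors
        # How do I call the backward of cdist and softmax here?
        raise NotImplementedError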

Have you looked into saved tensor hooks? They let you interpose behavior when a tensor is “packed” to be saved for backward and when it is “unpacked” to be used during the backward pass. Since GPU memory is your issue, you could consider the context manager torch.autograd.graph.save_on_cpu, which is implemented with saved tensor hooks. Inside this context, every tensor saved for backward is moved to the CPU and copied back to its original device when it is needed during the backward pass.
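
A minimal sketch of the save_on_cpu suggestion, assuming the cdist/softmax pair from the question and a fixed weight vector w so the loss is not constant (the helper name cdist_softmax_grads and the tensor shapes are made up for illustration). Only the forward pass has to run inside the context manager; the hooks are attached when tensors are saved, so backward can run outside the with block.

```python
import torch


def cdist_softmax_grads(a, b, offload):
    """Gradients of sum(softmax(cdist(a, b)) * w) w.r.t. a and b, optionally offloading saved tensors."""
    a = a.detach().requires_grad_()
    b = b.detach().requires_grad_()
    # Fixed weights: softmax rows sum to 1, so a plain .sum() would have zero gradient.
    w = torch.linspace(-1.0, 1.0, b.shape[0], device=a.device)
    if offload:
        # Tensors saved for backward in this block are copied to CPU memory (pinned when
        # pin_memory=True and CUDA is available) and moved back to their device in backward.
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            y = torch.softmax(torch.cdist(a, b), dim=-1)
    else:
        y = torch.softmax(torch.cdist(a, b), dim=-1)
    # backward may run outside the with block: the unpack hooks were attached at save time.
    (y * w).sum().backward()
    return a.grad, b.grad


device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(64, 3, device=device)
b = torch.randn(128, 3, device=device)
ga, gb = cdist_softmax_grads(a, b, offload=True)
ga_ref, gb_ref = cdist_softmax_grads(a, b, offload=False)
print(torch.allclose(ga, ga_ref), torch.allclose(gb, gb_ref))
```

The offloaded and non-offloaded runs should give matching gradients; what you pay for the lower GPU memory is a device-to-host copy in forward and a host-to-device copy in backward for every saved tensor.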

To answer your question directly, though: I believe some of the backward functions are exposed in the torch.ops.aten namespace, e.g., torch.ops.aten._softmax_backward_data and torch.ops.aten._cdist_backward. The full list is in the PyTorch source at aten/src/ATen/native/native_functions.yaml. These are NOT intended to be public API, though, and are subject to change.
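
If you would rather stay off the private aten ops, one public-API way to run the built-in backward of existing ops from inside a custom autograd.Function is to recompute them under torch.enable_grad() in backward and differentiate the recomputed graph with torch.autograd.grad. The following sketch swaps in that recomputation technique (essentially what activation checkpointing does); it is not the aten-based route, and the class name CdistSoftmax is hypothetical. Only a and b are saved between forward and backward, but x and y are still materialized once, transiently, during backward.

```python
import torch


class CdistSoftmax(torch.autograd.Function):
    """softmax(cdist(a, b), dim) that keeps only a and b for backward and recomputes the rest."""

    @staticmethod
    def forward(ctx, a, b, dim):
        # Function.forward runs with grad mode disabled, so no graph is kept for x and y.
        ctx.dim = dim
        ctx.save_for_backward(a, b)
        return torch.softmax(torch.cdist(a, b), dim=dim)

    @staticmethod
    def backward(ctx, grad_y):
        a, b = ctx.saved_tensors
        # Re-run the two built-in ops with grad enabled so autograd records their backward.
        with torch.enable_grad():
            a_ = a.detach().requires_grad_()
            b_ = b.detach().requires_grad_()
            y = torch.softmax(torch.cdist(a_, b_), dim=ctx.dim)
        # autograd.grad applies the built-in backward formulas of softmax and cdist.
        grad_a, grad_b = torch.autograd.grad(y, (a_, b_), grad_y)
        return grad_a, grad_b, None  # no gradient for the integer `dim` argument
```

Call it as y = CdistSoftmax.apply(a, b, -1). Running torch.autograd.gradcheck on small double-precision inputs is a quick way to confirm the gradients match the unfused ops.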

Actually, the 1st option looks better and more sustainable than the 2nd. I will try it out and report back here for future reference, thanks!


I gave the 1st option (save_on_cpu) a try. It works, but with a huge throughput penalty from the CPU-GPU tensor copies over PCIe. I will look into saved tensor hooks.

I eventually used autograd.graph to implement custom saved tensor hooks (saved_tensors_hooks) with a form of gradient checkpointing tailored to my application. It works very well now; thanks for the tip on autograd.graph!
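
For anyone finding this later: the exact hooks aren't shown in this thread, so the following is only a hypothetical sketch of custom saved tensor hooks for this case, not the approach described above. It assumes, as in the question, that the saved softmax output is mostly near zero: the pack hook zeroes entries below a made-up THRESHOLD and stores the tensor as sparse COO when at least MIN_SPARSITY of it would be dropped, and the unpack hook densifies it again. This is lossy, so gradients differ slightly from the exact ones, by an amount that depends on the threshold.

```python
import contextlib

import torch

THRESHOLD = 1e-6    # hypothetical cutoff: entries with |v| < THRESHOLD are treated as zero
MIN_SPARSITY = 0.5  # hypothetical: compress only if at least half of the entries would be dropped


def pack_sparse(t):
    # Called for every tensor autograd saves for backward inside the hooks context.
    if t.is_floating_point() and t.dim() >= 2:
        near_zero = t.detach().abs() < THRESHOLD
        if near_zero.float().mean().item() >= MIN_SPARSITY:
            # Lossy: drop near-zero entries and keep the rest as a sparse COO tensor.
            return ("sparse", t.detach().masked_fill(near_zero, 0).to_sparse())
    return ("dense", t)


def unpack_sparse(packed):
    # Called when backward needs the tensor; returns a dense tensor of the original shape.
    kind, t = packed
    return t.to_dense() if kind == "sparse" else t


def cdist_softmax_grads(a, b, w, use_hooks):
    """Gradients of sum(softmax(cdist(a, b)) * w) w.r.t. a and b, optionally with the hooks."""
    a = a.detach().requires_grad_()
    b = b.detach().requires_grad_()
    hooks = (torch.autograd.graph.saved_tensors_hooks(pack_sparse, unpack_sparse)
             if use_hooks else contextlib.nullcontext())
    with hooks:
        y = torch.softmax(torch.cdist(a, b), dim=-1)
    (y * w).sum().backward()
    return a.grad, b.grad
```

With widely spread points (e.g. a = 10 * torch.randn(64, 3), b = 10 * torch.randn(128, 3)), most softmax weights fall below the threshold and y is stored sparsely, while a, b and x stay dense. On GPU, the .item() call in pack_sparse forces a device sync on every save, which a real implementation would probably want to avoid.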

Waiting for @ptrblck's response…