torch.autograd.grad returns a gradient with None grad_fn even when create_graph=True

If it is a custom Function, you need to make sure that its backward is differentiable via autograd.
If you use non-differentiable ops in the backward, you will have to write a second Function whose forward is the backward of the first one. For example:

from torch.autograd import Function

class MyFn(Function):
    @staticmethod
    def forward(ctx, inp):
        # the forward itself can use non-differentiable ops
        return my_non_diff_forward(inp)

    @staticmethod
    def backward(ctx, gO):
        # delegate to a second Function so this backward is recorded in the graph
        return MyFnBackward.apply(gO)

class MyFnBackward(Function):
    @staticmethod
    def forward(ctx, inp):
        # this forward implements the (non-differentiable) backward of MyFn
        return my_non_diff_backward(inp)

    @staticmethod
    def backward(ctx, gO):
        # this provides the double backward of MyFn
        return my_diff_double_backward(gO)

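As an illustration only (not from the original post), here is a minimal self-contained sketch of the pattern with a toy op y = x * x, pretending that its backward 2 * x * gO had to be written with non-differentiable ops; the Square and SquareBackward names are made up for this example:

import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        return inp * inp

    @staticmethod
    def backward(ctx, gO):
        inp, = ctx.saved_tensors
        # delegate to a second Function so the gradient stays differentiable
        return SquareBackward.apply(gO, inp)

class SquareBackward(Function):
    @staticmethod
    def forward(ctx, gO, inp):
        ctx.save_for_backward(gO, inp)
        # this is the backward of Square: 2 * inp * gO
        return 2 * inp * gO

    @staticmethod
    def backward(ctx, ggI):
        gO, inp = ctx.saved_tensors
        # d(2*inp*gO)/dgO = 2*inp and d(2*inp*gO)/dinp = 2*gO
        return 2 * inp * ggI, 2 * gO * ggI

x = torch.randn(3, requires_grad=True)
y = Square.apply(x).sum()
g, = torch.autograd.grad(y, x, create_graph=True)
print(g.grad_fn)    # not None anymore, so double backward is possible
g.sum().backward()  # goes through SquareBackward.backward
print(x.grad)       # tensor([2., 2., 2.])
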
Note that if your double backward is not differentiable, you can add @once_differentiable (from torch.autograd.function import once_differentiable) to its backward to get a nice error if you ever try to backward through that in the future.
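
For reference, a minimal sketch of where that decorator goes, reusing the names above (my_non_diff_double_backward is a hypothetical placeholder for a double backward built from non-differentiable ops):

from torch.autograd.function import once_differentiable

class MyFnBackward(Function):
    @staticmethod
    def forward(ctx, inp):
        return my_non_diff_backward(inp)

    @staticmethod
    @once_differentiable  # errors out nicely if autograd tries to differentiate this
    def backward(ctx, gO):
        return my_non_diff_double_backward(gO)  # hypothetical non-differentiable double backward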