If it is a custom Function, you need to make sure that its backward is differentiable via autograd.
If you use non-differentiable ops in the backward, you will have to write a second Function whose forward is the backward of the first one. For example:
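    from torch.autograd import Function

    class MyFn(Function):
        @staticmethod
        def forward(ctx, inp):
            return my_non_diff_forward(inp)

        @staticmethod
        def backward(ctx, gO):
            # Delegate to a second Function so autograd can track this step.
            return MyFnBackward.apply(gO)

    class MyFnBackward(Function):
        @staticmethod
        def forward(ctx, inp):
            # The forward here is the (non-differentiable) backward of MyFn.
            return my_non_diff_backward(inp)

        @staticmethod
        def backward(ctx, gO):
            # Double backward of MyFn, written with differentiable ops.
            return my_diff_double_backward(gO)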
Note that if your double backward is not differentiable, you can add the @once_differentiable decorator (from torch.autograd.function import once_differentiable) to the backward of the second Function (MyFnBackward.backward here) to get a clear error if you ever try to backward through it in the future.
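
To make the pattern concrete, here is a minimal sketch for f(x) = x^3. The names Cube, CubeBackward, opaque_cube and opaque_cube_grad are made up for this illustration, and the two opaque_* helpers only stand in for ops autograd cannot trace (they simply detach their inputs). Unlike the skeleton above, the second Function also takes x as an input, since the first derivative depends on it.

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


def opaque_cube(x):
    # Hypothetical stand-in for an op autograd cannot trace
    # (a NumPy round trip, a custom kernel, ...).
    return x.detach() ** 3


def opaque_cube_grad(x, grad_out):
    # Same idea for the first derivative: 3 * x^2 * grad_out, cut off from the graph.
    return 3.0 * x.detach() ** 2 * grad_out.detach()


class Cube(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return opaque_cube(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Route the non-differentiable computation through a second Function
        # so that autograd records a node for it.
        return CubeBackward.apply(x, grad_out)


class CubeBackward(Function):
    @staticmethod
    def forward(ctx, x, grad_out):
        ctx.save_for_backward(x, grad_out)
        return opaque_cube_grad(x, grad_out)

    @staticmethod
    @once_differentiable  # this sketch does not support a third derivative
    def backward(ctx, grad_grad_in):
        x, grad_out = ctx.saved_tensors
        # Derivatives of 3 * x^2 * grad_out with respect to each forward input.
        grad_x = grad_grad_in * 6.0 * x * grad_out
        grad_grad_out = grad_grad_in * 3.0 * x ** 2
        return grad_x, grad_grad_out


x = torch.tensor([1.0, 2.0], dtype=torch.float64, requires_grad=True)
y = Cube.apply(x)
# First derivative, keeping the graph so it can be differentiated again.
(gx,) = torch.autograd.grad(y.sum(), x, create_graph=True)
# Second derivative goes through CubeBackward.backward.
(ggx,) = torch.autograd.grad(gx.sum(), x)
```

If Cube.backward called opaque_cube_grad directly instead of going through CubeBackward.apply, the first derivative would come back without a grad_fn and the second torch.autograd.grad call would fail, which is the situation the second Function is there to avoid.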