Custom loss function: how to hard-code the gradient w.r.t. the input?

In a custom loss function, the gradient w.r.t. the input is normally computed automatically by autograd.

However, I’m currently implementing a complicated loss function, e.g. the Wasserstein distance, which
is solved with an iterative algorithm named Sinkhorn-Knopp.

I think it would be more efficient to hard-code the gradient than to compute it by backpropagating through the computational graph.

So how can I manually specify the gradient w.r.t. the input for a custom loss function?
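
To make it concrete, this is roughly what I would like to be able to write. It is only a sketch, and it assumes PyTorch and its `torch.autograd.Function` interface, which may not be the right tool for my setup. The class name `ManualGradLoss` and the squared-error loss are hypothetical stand-ins for the real Sinkhorn-based loss; the point is only that `backward` returns a gradient I wrote by hand instead of one traced through the forward computation.

```python
import torch


class ManualGradLoss(torch.autograd.Function):
    """Toy loss with a hand-written gradient (hypothetical stand-in for a Sinkhorn loss)."""

    @staticmethod
    def forward(ctx, pred, target):
        # Operations inside forward() are not recorded by autograd,
        # so an iterative solver placed here would build no graph.
        diff = pred - target
        ctx.save_for_backward(diff)
        return 0.5 * diff.pow(2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (diff,) = ctx.saved_tensors
        # Hand-coded gradient: d(loss)/d(pred) = pred - target,
        # scaled by the incoming gradient of the scalar loss.
        grad_pred = grad_output * diff
        # One return value per forward() input; target needs no gradient.
        return grad_pred, None


pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([0.0, 2.0, 5.0])

loss = ManualGradLoss.apply(pred, target)
loss.backward()
print(loss.item(), pred.grad)
```

After `loss.backward()`, `pred.grad` holds whatever `backward` returned, so for the real loss I would replace the body of `forward` with the Sinkhorn iterations and the body of `backward` with my closed-form gradient.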