I have a custom module that rearranges the values of its input in a sophisticated way, so I have to extend autograd.
Since it only rearranges values, shouldn't the double backward of the gradients be the same kind of rearrangement as the backward, similar to reshape?
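Here is a small derivation of what the double backward should compute. It assumes the rearrangement is a pure permutation, meaning each output element is exactly one input element with no duplicates or drops, so it can be written with a permutation matrix $P$. In the derivation, $\bar{x}$ and $\bar{y}$ are the gradients with respect to $x$ and $y$, and $v$ is the gradient that arrives at $\bar{x}$ during the double backward. Under that assumption, the backward is linear in $\bar{y}$ and does not depend on $x$:

$$
\begin{aligned}
y &= P x && \text{(forward)} \\
\bar{x} &= P^\top \bar{y} && \text{(backward)} \\
\frac{\partial \bar{x}}{\partial x} &= 0, \qquad \frac{\partial L'}{\partial \bar{y}} = \left(P^\top\right)^\top v = P v && \text{(double backward)}
\end{aligned}
$$

So the double backward applies the *forward* rearrangement $P$ to $v$, and the gradient of the backward with respect to $x$ is zero. For reshape, $P^\top = P^{-1}$, so this amounts to reshaping back to the output shape.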
If I define it this way in XXXFunction.py:

```python
@staticmethod
def backward(ctx, grad_output):
    # do something to rearrange grad_output to get grad_input
    return grad_input
```
the double backward does not seem to be correct. How can I customize the double backward?
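One way to get a custom double backward is to make the backward itself call a second `torch.autograd.Function`. When `create_graph=True`, autograd records that call, and the second Function's own `backward` then defines the double backward explicitly. The sketch below assumes the custom op is a pure permutation along dim 0 given by an index tensor `perm`. `Rearrange`, `RearrangeBackward` and `perm` are hypothetical names standing in for the real module and its rearrangement logic.

```python
import torch
from torch.autograd import gradgradcheck


class RearrangeBackward(torch.autograd.Function):
    """Hypothetical helper: applies the inverse permutation to grad_output.

    Its own backward re-applies the forward permutation, which is the
    double backward of a pure rearrangement (assumption: perm is a permutation).
    """

    @staticmethod
    def forward(ctx, grad_output, perm, inv):
        ctx.save_for_backward(perm, inv)
        return grad_output[inv]

    @staticmethod
    def backward(ctx, grad_grad_input):
        perm, inv = ctx.saved_tensors
        # The gradient of (g -> g[inv]) with respect to g is (h -> h[perm])
        return grad_grad_input[perm], None, None


class Rearrange(torch.autograd.Function):
    """Hypothetical stand-in for the custom op: y = x[perm] along dim 0."""

    @staticmethod
    def forward(ctx, x, perm):
        ctx.save_for_backward(perm)
        return x[perm]

    @staticmethod
    def backward(ctx, grad_output):
        (perm,) = ctx.saved_tensors
        inv = torch.argsort(perm)  # inverse permutation
        # Calling another Function keeps this backward differentiable
        return RearrangeBackward.apply(grad_output, perm, inv), None


if __name__ == "__main__":
    perm = torch.tensor([2, 0, 4, 1, 3])
    x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
    # gradgradcheck numerically verifies the double backward
    print(gradgradcheck(lambda t: Rearrange.apply(t, perm), (x,)))
```

If the rearrangement can be written with ordinary differentiable tensor ops on `grad_output` (like the indexing here), returning `grad_output[inv]` directly from `Rearrange.backward` would also work, because autograd records those ops. The extra Function is mainly useful when the backward relies on non-differentiable code. Separately, decorating a backward with `torch.autograd.function.once_differentiable` makes any attempt at double backward raise an error rather than possibly returning incorrect second-order gradients.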