Gradient through black box iterative algorithm

Some recent works such as Deep Equilibrium Models use a black-box iterative algorithm that converges to some equilibrium state, and compute gradients at the equilibrium point, without back-propping through all the (arbitrary number of) iterations. What support is there for this in PyTorch?

Hi,

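It depends on how the black box is implemented.
If it is just PyTorch code, autograd can differentiate through the iterations directly (though this might be slow, and memory use grows with the number of iterations because every one of them is recorded in the graph).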
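A minimal sketch of that first case, assuming the black box is the illustrative fixed-point iteration z <- tanh(W z + x). The map, the shapes and the rescaling of W to spectral norm 0.5 (so the iteration is a contraction) are assumptions for the example, not something from the question. Plain autograd simply backpropagates through every step:

```python
import torch

torch.manual_seed(0)

# Illustrative black box: iterate z <- tanh(W z + x) towards its fixed point.
# W is rescaled to spectral norm 0.5 so the map is a contraction (assumption).
W = torch.randn(4, 4)
W = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)
x = torch.randn(4, requires_grad=True)

z = torch.zeros(4)
for _ in range(50):
    z = torch.tanh(W @ z + x)  # each step is recorded in the autograd graph

loss = z.sum()
loss.backward()  # backpropagates through all 50 recorded iterations
print(x.grad)
```

This is exact for the unrolled computation, but the graph (and its memory) grows linearly with the number of iterations.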
If you know some property of the equilibrium that lets you write dL/dinp = f(dL/dout), then you can use a custom autograd Function to tell autograd how to compute that gradient instead of backpropagating through the iterations.
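For a fixed point z* = f(z*, x), the implicit function theorem gives exactly such a relation: if u solves u = dL/dz* + (df/dz)^T u, then dL/dx = (df/dx)^T u, which is essentially what Deep Equilibrium Models do. Below is a hypothetical `FixedPoint` Function for the same illustrative map as above (the class name, `N_ITER` and the plain fixed-point solver are assumptions). `forward` iterates without recording a graph; `backward` rebuilds one application of f at the equilibrium and solves for u with vector-Jacobian products from `torch.autograd.grad`. It assumes f is a contraction so both iterations converge; a real implementation would more likely use a stronger solver (e.g. Broyden or Anderson) with a convergence check.

```python
import torch

N_ITER = 100  # hypothetical fixed budget for both the forward and backward solves


class FixedPoint(torch.autograd.Function):
    """Hypothetical sketch: returns z* with z* = tanh(W z* + x) and
    differentiates it implicitly instead of through the iterations."""

    @staticmethod
    def forward(ctx, x, W):
        # The forward of a custom Function is not recorded by autograd,
        # so these iterations cost no graph memory.
        z = torch.zeros_like(x)
        for _ in range(N_ITER):
            z = torch.tanh(W @ z + x)
        ctx.save_for_backward(z, x, W)
        return z

    @staticmethod
    def backward(ctx, grad_z):
        z, x, W = ctx.saved_tensors
        with torch.enable_grad():
            # Rebuild a single application of f at the equilibrium.
            z_ = z.detach().requires_grad_()
            x_ = x.detach().requires_grad_()
            W_ = W.detach().requires_grad_()
            fz = torch.tanh(W_ @ z_ + x_)
            # Solve u = grad_z + (df/dz)^T u by fixed-point iteration,
            # using vector-Jacobian products from torch.autograd.grad.
            u = grad_z.clone()
            for _ in range(N_ITER):
                (vjp,) = torch.autograd.grad(fz, z_, u, retain_graph=True)
                u = grad_z + vjp
            # dL/dx = (df/dx)^T u and dL/dW = (df/dW)^T u
            grad_x, grad_W = torch.autograd.grad(fz, (x_, W_), u)
        return grad_x, grad_W


torch.manual_seed(0)
W = torch.randn(4, 4)
W = (0.5 * W / torch.linalg.matrix_norm(W, ord=2)).requires_grad_()
x = torch.randn(4, requires_grad=True)

z = FixedPoint.apply(x, W)
z.sum().backward()
print(x.grad)
print(W.grad)
```

Only one application of f is kept for the backward pass, so memory no longer depends on how many iterations the forward solver needed.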