How to modify backward() to implement a reversible residual network

The backward pass needs not only grad_output but also the block's output to compute the gradients, so the input activation maps don't have to be stored. However, backward() of torch.autograd.Function only receives grad_output as its gradient argument.
Reference: The Reversible Residual Network: Backpropagation Without Storing Activations (Gomez et al., 2017)
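One common workaround is to save the *outputs* with ctx.save_for_backward() in forward, then reconstruct the inputs from them inside backward() and recompute the local gradients there with torch.autograd.grad(). Below is a minimal sketch for an additive coupling block y1 = x1 + F(x2), y2 = x2 + G(y1), assuming F and G are parameter-free callables (gradients for module parameters are omitted for brevity; a full RevNet backward would also accumulate those):

```python
import torch


class ReversibleBlock(torch.autograd.Function):
    """Additive coupling block: y1 = x1 + F(x2), y2 = x2 + G(y1).

    Only the outputs (y1, y2) are saved for backward; the inputs are
    reconstructed from them, so no input activation maps are stored.
    """

    @staticmethod
    def forward(ctx, x1, x2, f, g):
        # Inside Function.forward autograd is disabled, so no graph is built.
        y1 = x1 + f(x2)
        y2 = x2 + g(y1)
        ctx.f, ctx.g = f, g
        ctx.save_for_backward(y1, y2)  # save outputs, not inputs
        return y1, y2

    @staticmethod
    def backward(ctx, dy1, dy2):
        y1, y2 = ctx.saved_tensors
        with torch.enable_grad():
            # Recompute G(y1) to (a) reconstruct x2 and (b) get the part of
            # dL/dy1 that flows through y2 = x2 + G(y1).
            y1 = y1.detach().requires_grad_()
            gy1 = ctx.g(y1)
            dx1 = dy1 + torch.autograd.grad(gy1, y1, dy2)[0]
            # Reconstruct the input x2, then recompute F(x2) for its gradient.
            x2 = (y2 - gy1).detach().requires_grad_()
            fx2 = ctx.f(x2)
            dx2 = dy2 + torch.autograd.grad(fx2, x2, dx1)[0]
        # (x1 could be reconstructed as y1 - fx2 if an earlier block needs it.)
        return dx1, dx2, None, None
```

Calling ReversibleBlock.apply(x1, x2, F, G) and comparing against an ordinary (activation-storing) forward/backward should give matching input gradients; the names ReversibleBlock, f, and g here are illustrative, not from an existing library.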