I’m implementing a reversible convolution layer, so that the forward pass does not need to save its input x.
I already have the inverse function implemented. I’m wondering how to reuse the original implementation of conv backward, without “re-implementing” it in Python, while providing the ctx manually this time.
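To make the question concrete, here is roughly what I have in mind — a sketch, not working reversible-net code. `MemoryFreeConv2d` and `inverse_fn` are placeholder names of mine; `torch.nn.grad.conv2d_input` and `torch.nn.grad.conv2d_weight` are PyTorch's existing functional conv-gradient helpers, which is the closest I've found to "reusing conv backward":

```python
import torch
import torch.nn.functional as F
from torch.nn import grad as nn_grad

class MemoryFreeConv2d(torch.autograd.Function):
    """Sketch: a conv2d that does NOT save x in ctx, and instead
    reconstructs x in backward from y via a supplied inverse."""

    @staticmethod
    def forward(ctx, x, weight, inverse_fn):
        y = F.conv2d(x, weight)
        ctx.save_for_backward(weight)   # note: x is deliberately not saved
        ctx.inverse_fn = inverse_fn
        ctx.x_shape = x.shape
        # For a single layer we still have to keep y; in a reversible
        # *stack*, y would instead come from the next layer's reconstruction.
        ctx.y = y.detach()
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (weight,) = ctx.saved_tensors
        x = ctx.inverse_fn(ctx.y)       # recompute the input from the output
        # Reuse PyTorch's conv gradients instead of re-implementing them:
        grad_x = nn_grad.conv2d_input(ctx.x_shape, weight, grad_y)
        grad_w = nn_grad.conv2d_weight(x, weight.shape, grad_y)
        return grad_x, grad_w, None
```

The downside is that this reimplements the forward/backward wiring by hand rather than hooking into the ctx of the built-in layer, which is what I'd like to avoid.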
Alternatively, is there any way to add “hooks” to these PyTorch layers that “remove” certain ctx tensors immediately after forward and “insert” them back just before backward?
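The closest mechanism I'm aware of is `torch.autograd.graph.saved_tensors_hooks` (PyTorch ≥ 1.10): the pack hook fires on every tensor autograd saves right after forward, and the unpack hook fires just before backward, which sounds like exactly this "remove after forward / insert before backward" pattern. A minimal sketch, where the CPU offload is only a stand-in for whatever drop-and-recompute policy is wanted (for full recompute-in-backward, `torch.utils.checkpoint` is the packaged version):

```python
import torch

logged = []  # record which tensors the conv's backward wanted to save

def pack_hook(t):
    # Called immediately after forward for every tensor autograd saves.
    # May return any object standing in for the tensor; here we just
    # move it off the accelerator (a placeholder for "dropping" it).
    logged.append(tuple(t.shape))
    return t.cpu() if t.is_cuda else t

def unpack_hook(packed):
    # Called just before backward; must return the original tensor.
    return packed

x = torch.randn(1, 3, 8, 8, requires_grad=True)
conv = torch.nn.Conv2d(3, 3, 3, padding=1)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = conv(x)  # conv's ctx (input, weight) passes through pack_hook
y.sum().backward()  # saved tensors pass back through unpack_hook
```

What I don't see is how to make `unpack_hook` *recompute* x from y via my inverse rather than restore a stored copy, since the hook only receives whatever `pack_hook` returned.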