How to provide ctx for torch native functions manually

I’m implementing a reversible convolution layer, so that the forward pass does not need to save x.

I already have the inverse function implemented. How can I reuse the original conv backward implementation, without re-implementing it in Python, while providing the ctx contents (the input x) manually?

Alternatively, is there a way to add “hooks” to these PyTorch layers that “remove” certain saved tensors from ctx immediately after forward and “insert” them back before backward?

You could call the native conv backward functions directly, passing the input and weight yourself:

torch.nn.grad.conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1)
computes the gradient of conv2d with respect to its input.

torch.nn.grad.conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1)
computes the gradient of conv2d with respect to its weight.
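To show how these two functions could be wired together, here is a minimal sketch of a custom torch.autograd.Function that does not save x and instead rebuilds it in backward. The names ReversibleConv2d and inverse_1x1 are made up for illustration, and the inverse assumes a 1x1 convolution with an invertible channel-mixing matrix; your own inverse function would take its place. For simplicity the sketch keeps the output y in ctx, whereas a real reversible network would probably obtain y from the following layer.

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_input, conv2d_weight


def inverse_1x1(y, weight):
    # Hypothetical inverse, valid only for a 1x1 conv with an invertible
    # (C, C) channel matrix: apply the inverted matrix as another 1x1 conv.
    w_inv = torch.inverse(weight[:, :, 0, 0])[:, :, None, None]
    return F.conv2d(y, w_inv)


class ReversibleConv2d(torch.autograd.Function):
    # Illustrative name; stride=1, padding=0, dilation=1, groups=1 are assumed.

    @staticmethod
    def forward(ctx, x, weight, inverse_fn):
        y = F.conv2d(x, weight)
        # Save weight and output only; the input x is deliberately not saved.
        ctx.save_for_backward(weight, y)
        ctx.inverse_fn = inverse_fn
        return y

    @staticmethod
    def backward(ctx, grad_y):
        weight, y = ctx.saved_tensors
        # Rebuild x from y instead of reading it from ctx.
        x = ctx.inverse_fn(y, weight)
        grad_x = conv2d_input(x.shape, weight, grad_y)
        grad_w = conv2d_weight(x, weight.shape, grad_y)
        # One return value per forward argument; inverse_fn gets no gradient.
        return grad_x, grad_w, None


torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8, dtype=torch.double, requires_grad=True)
# A well-conditioned 1x1 kernel: identity plus a small perturbation.
w = (torch.eye(4, dtype=torch.double)
     + 0.1 * torch.randn(4, 4, dtype=torch.double)).reshape(4, 4, 1, 1)
w = w.clone().requires_grad_()
r = torch.randn(2, 4, 8, 8, dtype=torch.double)  # arbitrary upstream gradient

y = ReversibleConv2d.apply(x, w, inverse_1x1)
gx, gw = torch.autograd.grad((y * r).sum(), (x, w))
```

The gradients gx and gw can be compared against those of a plain F.conv2d call to check that the reconstruction is accurate enough for your kernel.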