@hanspinckaers: Thank you for following up. It has been a while since I did the benchmark. I recall I was training a ResNet18 on ImageNet. Using PyTorch's torch.nn.grad.conv2d_input(...) and torch.nn.grad.conv2d_weight(...) was roughly twice as slow and used about twice as much memory as letting PyTorch derive the backward pass of Conv2d automatically.
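For reference, this is roughly the kind of manual backward I was benchmarking. A minimal sketch only (no bias, dilation, or groups; the class name is just illustrative):

```python
import torch
import torch.nn.functional
import torch.nn.grad

class ManualConv2d(torch.autograd.Function):
    """Conv2d whose backward is written by hand with torch.nn.grad helpers."""

    @staticmethod
    def forward(ctx, input, weight, stride=1, padding=0):
        ctx.save_for_backward(input, weight)
        ctx.stride, ctx.padding = stride, padding
        return torch.nn.functional.conv2d(input, weight, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Gradient w.r.t. the input feature map.
        grad_input = torch.nn.grad.conv2d_input(
            input.shape, weight, grad_output, stride=ctx.stride, padding=ctx.padding)
        # Gradient w.r.t. the convolution weights.
        grad_weight = torch.nn.grad.conv2d_weight(
            input, weight.shape, grad_output, stride=ctx.stride, padding=ctx.padding)
        # No gradients for the non-tensor stride/padding arguments.
        return grad_input, grad_weight, None, None

# Usage: out = ManualConv2d.apply(x, w, 1, 1)
```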
When I tried the method you provided in this link, it made things a bit faster, but it was still much slower than PyTorch's automatic backward pass.
@fsds: Thanks for your answer. Are you referring to std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(...) in:
So I just need to create a Python wrapper for it and invoke it in our backward pass?
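If so, a minimal sketch of such a wrapper might look like the code below. Note this is an assumption based on the declaration quoted above, not a drop-in implementation: the exact signature of cudnn_convolution_backward varies across PyTorch releases (some insert an extra allow_tf32 flag before the output mask, older ones also return a bias gradient, and newer ones remove the function in favor of at::convolution_backward), so it needs to be adjusted to the installed version:

```cpp
// conv_backward_ext.cpp -- hypothetical file name for a small C++ extension.
// Thin Python-callable wrapper around ATen's cudnn_convolution_backward.
// Signature assumed from the declaration quoted above; check it against
// your PyTorch release before building.
#include <torch/extension.h>

std::tuple<at::Tensor, at::Tensor> conv2d_backward(
    const at::Tensor& input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    std::vector<int64_t> padding,
    std::vector<int64_t> stride,
    std::vector<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {
  // Ask cuDNN for both grad_input and grad_weight.
  std::array<bool, 2> output_mask = {true, true};
  return at::cudnn_convolution_backward(
      input, grad_output, weight,
      padding, stride, dilation, groups,
      benchmark, deterministic, output_mask);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("conv2d_backward", &conv2d_backward,
        "cuDNN Conv2d backward returning (grad_input, grad_weight)");
}
```

Building it with torch.utils.cpp_extension.load(name="conv_backward_ext", sources=["conv_backward_ext.cpp"]) should give a module whose conv2d_backward you can then call from the backward of a custom autograd.Function.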