Turns out it does work; the error just indicates that the shapes are not correct. It seems torch.nn.grad.conv2d_weight
is a bit more forgiving about wrong shapes.
I tried:
grad_output shape: torch.Size([1, 32, 46, 46])
input shape: torch.Size([1, 128, 49, 49])
But it should have been (with padding 0):
input shape: torch.Size([1, 128, 48, 48])
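
For anyone hitting the same error, a quick way to sanity-check the shapes is the standard convolution output-size formula. This is only a sketch: the kernel size, stride and dilation are not stated above, so I am assuming a 3x3 kernel with stride 1 and dilation 1 (which is consistent with 48 -> 46 at padding 0). `conv_output_size` is a hypothetical helper, not a PyTorch function.

```python
def conv_output_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution along one dimension."""
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Assumed for illustration: 3x3 kernel, stride 1, dilation 1, padding 0.
print(conv_output_size(48, kernel_size=3))  # matches grad_output's 46
print(conv_output_size(49, kernel_size=3))  # one larger, so grad_output would not match
```

Under these assumptions only the 48x48 input lines up with a 46x46 grad_output, so the 49x49 input is the mismatch.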
cudnn_convolution_backward_weight is about 3x faster than torch.nn.grad.conv2d_weight in my case.