In a current project of mine, I need to project the weights of my convolutions. To do that, I wrote my own module that uses `torch.nn.functional.conv2d` instead of `torch.nn.Conv2d`.
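For reference, the functional call should compute the same result as the layer when it is given the layer's weight tensor. Below is a minimal check, assuming a bias-free layer with one input channel, one output channel and a 7x7 kernel like the one in my module. The `nn.Conv2d` instance exists only for the comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stock layer matching the module's setup: 1 in-channel, 1 out-channel, 7x7 kernel, no bias
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=7, bias=False)

# Input uniformly distributed in [-1, 1)
x = torch.rand(24, 1, 48, 48) * 2 - 1

# Layer call vs. functional call with the layer's own weight tensor
y_module = conv(x)
y_functional = F.conv2d(x, conv.weight)

print(y_module.shape)  # torch.Size([24, 1, 42, 42]), since 48 - 7 + 1 = 42
print(torch.allclose(y_module, y_functional))
```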

However, I realized that it is very slow when I run the model on a GPU instead of a CPU.

Even more surprisingly, calling `F.conv2d` directly outside the module, with the same weights on the GPU, is much faster.

Here is a minimal example that reproduces the problem:

```python
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


class PytorchConvMatrixModule(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        # Single-channel kernel stored as a learnable parameter
        self.conv_weights = nn.Parameter(torch.randn((1, 1, kernel_size, kernel_size)))

    def forward(self, x, to_transpose):
        if to_transpose:
            return F.conv_transpose2d(x, self.conv_weights)
        else:
            return F.conv2d(x, self.conv_weights)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Two modules with identical weights: one moved to the GPU, one kept on the CPU
conv_module_gpu = PytorchConvMatrixModule(7)
conv_module_cpu = PytorchConvMatrixModule(7)
conv_module_cpu.conv_weights = nn.Parameter(conv_module_gpu.conv_weights.clone())
conv_module_gpu.to(device)

# Input uniformly distributed in [-1, 1)
x = (torch.rand([24, 1, 48, 48]) * 2 - 1).to(device)

start = time.time()
y_gpu = conv_module_gpu(x, to_transpose=False)
print("Time for convolution on GPU:", time.time() - start)

start = time.time()
y_cpu = conv_module_cpu(x.cpu(), to_transpose=False)
print("Time for convolution on CPU:", time.time() - start)

# Same GPU weights, but calling F.conv2d directly instead of through the module
weights = conv_module_gpu.conv_weights
start = time.time()
y_direct = F.conv2d(x, weights)
print("Time for convolution on GPU without model: ", time.time() - start)
```

I then get the following output:

```
Time for convolution on GPU: 1.6394927501678467
Time for convolution on CPU: 0.002363443374633789
Time for convolution on GPU without model: 0.0002646446228027344
```

This really surprises me, because as far as I can tell nothing significant differs between the two GPU calls, yet the call through the module is over 6000 times slower than the direct call. The input tensor is simply uniformly distributed between -1 and 1.
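One thing I could check is whether the gap comes from how I measure it. CUDA kernels are launched asynchronously, so reading `time.time()` right after a GPU call may only capture the launch, and the first convolution on the GPU may also pay one-time setup costs (for example cuDNN initialization). The sketch below is an untested variant, not the script above: it times the same `F.conv2d` call several times in a row and calls `torch.cuda.synchronize()` before each clock reading. It falls back to the CPU when no GPU is available.

```python
import time

import torch
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def sync() -> None:
    # GPU kernels run asynchronously; wait for them to finish before reading the clock.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


weights = torch.randn(1, 1, 7, 7, device=device)
x = torch.rand(24, 1, 48, 48, device=device) * 2 - 1

# Time the identical call repeatedly to separate one-time costs from steady-state cost
timings = []
for i in range(5):
    sync()
    start = time.perf_counter()
    y = F.conv2d(x, weights)
    sync()
    timings.append(time.perf_counter() - start)
    print(f"call {i}: {timings[-1]:.6f} s")
```

If only the first iteration is slow, the 1.6 s would be a warm-up cost charged to whichever GPU convolution runs first, rather than something specific to the module.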

I am using PyTorch 2.0.1, and my GPU is a Titan RTX.
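To compare the module call and the direct call in steady state, `torch.utils.benchmark.Timer` may be a better tool than `time.time()`. According to its documentation, it performs warm-up runs and synchronizes CUDA when necessary. Below is a sketch that repeats the class from the question so the snippet runs on its own. The `min_run_time` value is an arbitrary choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import benchmark


class PytorchConvMatrixModule(nn.Module):
    # Same module as in the question
    def __init__(self, kernel_size):
        super().__init__()
        self.conv_weights = nn.Parameter(torch.randn((1, 1, kernel_size, kernel_size)))

    def forward(self, x, to_transpose):
        if to_transpose:
            return F.conv_transpose2d(x, self.conv_weights)
        return F.conv2d(x, self.conv_weights)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
module = PytorchConvMatrixModule(7).to(device)
x = torch.rand(24, 1, 48, 48, device=device) * 2 - 1

# Timer handles warm-up and CUDA synchronization around the measured statement
t_module = benchmark.Timer(
    stmt="module(x, to_transpose=False)",
    globals={"module": module, "x": x},
).blocked_autorange(min_run_time=0.2)

t_direct = benchmark.Timer(
    stmt="F.conv2d(x, w)",
    globals={"F": F, "x": x, "w": module.conv_weights},
).blocked_autorange(min_run_time=0.2)

# Median time per call, in microseconds
print(f"module call: {t_module.median * 1e6:.1f} us")
print(f"direct call: {t_direct.median * 1e6:.1f} us")
```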