Conv2D forward and backward pass very slow

Hello,

I am using Python 3.10.13 and PyTorch 2.2.2+cu121.

I updated a small part of my network and basically replaced an MLP with a multi-head attention. My input is of shape (Batch, Channel, Sequence_length, Embedding_dim). To do so I had to add two Conv2D layers: one at the entrance of the attention layer to collapse the channel dimension into 1, (B, C, S, E) → (B, 1, S, E) → (B, S, E), and one at the exit to recreate the channel dimension and continue with the rest of the model, (B, S, E) → (B, 1, S, E) → (B, C, S, E).
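
To make it concrete, here is a minimal sketch of the kind of setup I am describing (the 1x1 kernel size, variable names, and tensor sizes are illustrative placeholders, not my exact code):

```python
import torch
import torch.nn as nn

B, C, S, E = 32, 4, 50, 600  # illustrative sizes

# collapse the channel dimension: (B, C, S, E) -> (B, 1, S, E) -> (B, S, E)
collapse = nn.Conv2d(in_channels=C, out_channels=1, kernel_size=1)
# recreate the channel dimension: (B, S, E) -> (B, 1, S, E) -> (B, C, S, E)
expand = nn.Conv2d(in_channels=1, out_channels=C, kernel_size=1)

x = torch.randn(B, C, S, E)
h = collapse(x).squeeze(1)      # (B, S, E), fed to the attention block
y = expand(h.unsqueeze(1))      # (B, C, S, E), back to the rest of the model
```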

My issue is that I noticed a huge slowdown of the network (regardless of GPU/CPU/multi-GPU). On a GPU, an epoch used to take a minute and now takes 40 minutes.
I found that the forward pass still takes the same amount of time, but the backward pass has become 10x slower than the forward pass instead of 1-2x slower as before. After manually timing the Conv2D and attention layers, I noticed the Conv2D layers were dramatically slower despite having far fewer parameters (less than a hundred vs. millions for the attention layer).

To get the numbers below, I just created the layers in isolation, used random inputs for the forward pass, and computed a dummy loss for the backward pass.
On CPU, the Conv2D at the entrance (C → 1) takes ~0.004 ms for the forward pass and ~0.040 ms for the backward pass, so the backward pass is 10x slower. The Conv2D at the exit (1 → C) takes ~0.05 ms forward and ~0.10 ms backward, so its backward pass is only 2x slower (which is normal), but both forward and backward are very slow and should have been on the order of ~0.001 ms, shouldn't they?
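For reference, this is roughly how I obtained the timings (a simplified sketch: the shapes, the 1x1 kernel, and the dummy MSE-style loss are placeholders for my actual setup; on GPU, a torch.cuda.synchronize() before each time measurement would also be needed):

```python
import time
import torch
import torch.nn as nn

def time_fwd_bwd(layer, x, n_iters=100):
    """Average forward/backward time in ms, using a dummy loss for backward."""
    fwd = bwd = 0.0
    for _ in range(n_iters):
        t0 = time.perf_counter()
        out = layer(x)
        t1 = time.perf_counter()
        out.pow(2).mean().backward()   # dummy loss
        t2 = time.perf_counter()
        layer.zero_grad(set_to_none=True)
        fwd += t1 - t0
        bwd += t2 - t1
    return fwd / n_iters * 1e3, bwd / n_iters * 1e3

B, C, S, E = 32, 4, 50, 600
conv_in = nn.Conv2d(C, 1, kernel_size=1)   # entrance conv (C -> 1)
conv_out = nn.Conv2d(1, C, kernel_size=1)  # exit conv (1 -> C)
print("entrance conv (fwd ms, bwd ms):", time_fwd_bwd(conv_in, torch.randn(B, C, S, E)))
print("exit conv (fwd ms, bwd ms):    ", time_fwd_bwd(conv_out, torch.randn(B, 1, S, E)))
```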

The attention layer takes an input of shape (S, E), and both its forward and backward pass take ~0.03 ms.

Neither Conv2D has even 100 parameters, so why does it take so much longer to forward/backward those layers than it takes to forward/backward an attention layer or an MLP with millions of parameters?

Best

So I tested my code a bit more, and apparently the input size is the one and only culprit here.

With an input size of (B, 50, 4), the backward pass is 1.5x slower than the forward pass (so it's normal).
With an input size of (B, 50, 600), the backward pass is 10x slower than the forward pass (a minimal sketch of this comparison is below).
My model is just a U-Net.
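
As a rough illustration, using the same kind of 1x1 Conv2D as above and treating those shapes as (S, E) with a channel dimension added (the exact layers in my real model differ), the comparison I did looks roughly like this:

```python
import time
import torch
import torch.nn as nn

def bwd_over_fwd_ratio(S, E, B=32, C=4, n_iters=200):
    """Backward/forward time ratio for a 1x1 Conv2d on a (B, C, S, E) input."""
    layer = nn.Conv2d(C, 1, kernel_size=1)
    x = torch.randn(B, C, S, E)
    fwd = bwd = 0.0
    for _ in range(n_iters):
        t0 = time.perf_counter()
        out = layer(x)
        t1 = time.perf_counter()
        out.pow(2).mean().backward()   # dummy loss
        t2 = time.perf_counter()
        layer.zero_grad(set_to_none=True)
        fwd += t1 - t0
        bwd += t2 - t1
    return bwd / fwd

print("E=4:  ", bwd_over_fwd_ratio(50, 4))    # small embedding dim
print("E=600:", bwd_over_fwd_ratio(50, 600))  # large embedding dim
```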

Is this 10x slowdown normal?! I tried torch.compile(), and even though it is faster, the ratio stays about the same.
I've trained many different architectures before and this is the first time I've seen such a difference between the forward and backward pass, even with the same number of parameters.