Sudden drop in training speed after a few batches on CUDA

pytorch 1.11.0
pytorch-cluster 1.6.0
pytorch-geometric 2.0.4
pytorch-mutex 1.0
pytorch-scatter 2.0.9
pytorch-sparse 0.6.14
pytorch-spline-conv 1.2.1
torchvision 0.12.0
cuda 11.5
cudnn 8.3.2
windows 10

I am training a custom graph neural network with PyTorch, and training speed suddenly drops by a factor of about 6 after the first few batches. I used the PyTorch profiler to profile one “epoch” (with the amount of data greatly reduced). Everything is on CUDA, including the data (I wrote my own dataloader, as suggested here). My GPU (NVIDIA Quadro P2200) has 5 GB of dedicated memory, and usage never exceeds about 70% of it. GPU utilization is also only around 35%.

The slowdown is not specific to any one function: every operation takes about 6x longer.
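One way to narrow down where the slowdown starts is to record per-batch wall-clock times and compare them against an early baseline. The following is a minimal sketch. `time_batches`, `first_slow_batch` and their parameters are hypothetical helpers, not part of PyTorch, and the 3x threshold is an arbitrary assumption. CUDA kernels run asynchronously, so the timings are only meaningful if a synchronizing callable such as `torch.cuda.synchronize` is passed in.

```python
import statistics
import time
from typing import Callable, List, Optional


def time_batches(step: Callable[[int], None], n_batches: int,
                 sync: Optional[Callable[[], None]] = None) -> List[float]:
    """Run step(i) for each batch index and return per-batch wall-clock seconds.

    `sync` should block until all queued GPU work has finished (for example
    torch.cuda.synchronize). Without it, asynchronous kernel launches make a
    batch look faster than the GPU work it actually triggers.
    """
    times: List[float] = []
    for i in range(n_batches):
        if sync is not None:
            sync()  # drain work left over from the previous batch
        start = time.perf_counter()
        step(i)
        if sync is not None:
            sync()  # wait for this batch's kernels to finish
        times.append(time.perf_counter() - start)
    return times


def first_slow_batch(times: List[float], factor: float = 3.0,
                     warmup: int = 1, window: int = 3) -> Optional[int]:
    """Return the index of the first batch slower than factor x the baseline.

    The baseline is the median time of `window` batches after skipping
    `warmup` batches (the first batch often pays one-off setup costs).
    Returns None if no later batch exceeds the threshold.
    """
    baseline_slice = times[warmup:warmup + window]
    if not baseline_slice:
        return None
    baseline = statistics.median(baseline_slice)
    for i in range(warmup + window, len(times)):
        if times[i] > factor * baseline:
            return i
    return None
```

With PyTorch this could be called as `time_batches(lambda i: train_step(batches[i]), len(batches), sync=torch.cuda.synchronize)`, where `train_step` and `batches` are placeholders for the training loop's own step function and list of on-GPU batches. Passing `sync` as a parameter keeps the helper itself independent of PyTorch.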

Here is the trace I am producing with the torch profiler:

What could cause the sudden performance drop and how can I circumvent this? Thanks!