BatchNorm2d making training extremely slow

I’m implementing the SuperPoint network, whose backbone uses a VGG-style architecture. I haven’t been able to reproduce the paper’s results with my implementation. Some unofficial implementations [1] [2] use batch normalization in the backbone, so I added it as well and started training on the MS-COCO dataset. With BN the model trains at around 0.3 it/s, whereas without BN it runs at over 10 it/s. I’m using PyTorch Lightning to train on two A6000 Ada GPUs. I’m not sure where I’m going wrong or why BatchNorm2d is causing this slowdown.
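One way to check whether BatchNorm2d itself is the bottleneck is to time a few plain training steps of a VGG-style encoder with and without BN, outside Lightning and on a single device. The sketch below assumes plain PyTorch. `conv_block`, `VGGBackbone` and `time_train_steps` are hypothetical helpers, and the channel widths (64, 64, 128, 128) are an assumption loosely modeled on SuperPoint’s shared encoder, not taken from the implementations linked above.

```python
import time

import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int, use_bn: bool) -> nn.Sequential:
    """Two 3x3 convs (VGG-style), each optionally followed by BatchNorm2d, then ReLU."""
    layers = []
    for c_in in (in_ch, out_ch):
        # The conv bias is redundant when BatchNorm follows, so drop it in that case
        layers.append(nn.Conv2d(c_in, out_ch, kernel_size=3, padding=1, bias=not use_bn))
        if use_bn:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


class VGGBackbone(nn.Module):
    """Hypothetical SuperPoint-like encoder: 4 double-conv stages, 3 max-pools (1/8 resolution)."""

    def __init__(self, use_bn: bool, widths=(64, 64, 128, 128)):
        super().__init__()
        stages = []
        in_ch = 1  # assumed grayscale input
        for i, w in enumerate(widths):
            stages.append(conv_block(in_ch, w, use_bn))
            if i < len(widths) - 1:
                stages.append(nn.MaxPool2d(2))
            in_ch = w
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def time_train_steps(model: nn.Module, x: torch.Tensor, steps: int = 5) -> float:
    """Average wall-clock seconds per forward + backward + optimizer step."""
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    model.train()  # BN uses batch statistics in train mode
    start = 0.0
    for i in range(steps + 1):
        if i == 1:  # start the clock after one warm-up step
            if x.is_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
        opt.zero_grad()
        loss = model(x).mean()  # dummy loss, only used to drive backward()
        loss.backward()
        opt.step()
    if x.is_cuda:
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
    return (time.perf_counter() - start) / steps


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(8, 1, 120, 160, device=device)
    for use_bn in (False, True):
        model = VGGBackbone(use_bn).to(device)
        print(f"use_bn={use_bn}: {time_train_steps(model, x):.4f} s/step")
```

If the two variants time similarly here, that would suggest the slowdown comes from elsewhere in the training setup rather than from the BN computation itself.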

Can anyone explain, conceptually, what causes this slowdown?

Original: GitHub - magicleap/SuperPointPretrainedNetwork: PyTorch pre-trained model for real-time interest point detection, description, and sparse tracking