Changing dtype drastically affects training time

I’ve implemented a Wide ResNet and I’m trying to fine-tune it on CIFAR-10. The model’s last layer is a linear layer, and right before it I compute:
out = torch.mean(out, dim=(2,3), dtype=torch.float64)
When the dtype was float32, each epoch took around 15 seconds at most, but now it’s taking almost 4 minutes, which is a drastic increase. Any idea why?
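
To make the comparison concrete, here is a rough sketch of the kind of timing I’m doing. The small conv stack below is just a stand-in for the Wide ResNet, and the layer sizes and batch size are illustrative, not my actual code:

import time
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

def make_model(dtype):
    # Stand-in backbone (placeholder for the Wide ResNet) plus the linear head.
    backbone = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
        nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    ).to(device)
    fc = nn.Linear(64, 10).to(device=device, dtype=dtype)
    return backbone, fc

x = torch.randn(128, 3, 32, 32, device=device)  # CIFAR-10-sized batch

for dtype in (torch.float32, torch.float64):
    backbone, fc = make_model(dtype)
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(20):
        out = backbone(x)
        # Global average pooling; with dtype=torch.float64 everything
        # downstream (the linear layer and its backward pass) runs in fp64.
        out = torch.mean(out, dim=(2, 3), dtype=dtype)
        loss = fc(out).sum()
        loss.backward()
    if device == "cuda":
        torch.cuda.synchronize()
    print(dtype, f"{time.time() - t0:.3f}s")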

Could you post a minimal and executable code snippet reproducing the slowdown, please?