torch.nn.functional.interpolate() slow on MPS device (sometimes)

I’m developing an inference pipeline on MPS, and the last step is a small upscale from a shape of (1, 25, 497, 497) to (1, 25, 512, 512) using torch.nn.functional.interpolate(). I’ve noticed that this operation takes over 3 seconds on an MPS device (M1 Pro), whereas it takes less than 1/10 of a second on a CUDA device (T4). While some performance gap between these two GPUs is expected across the board, this difference seems wildly disproportionate.
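For context, here is a minimal standalone sketch of the call in question (the mode="bilinear" and align_corners=False settings are placeholders; substitute whatever the pipeline actually uses):

```python
import torch
import torch.nn.functional as F

device = torch.device("mps")
x = torch.randn(1, 25, 497, 497, device=device)

# Upscale (497, 497) -> (512, 512); mode/align_corners are placeholder settings
y = F.interpolate(x, size=(512, 512), mode="bilinear", align_corners=False)
```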

I’ve also noticed that if I just create an arbitrary tensor of size (1, 25, 497, 497) and perform the same interpolation outside of the network, or in a different Jupyter cell, the interpolation takes only about 1/10 of a second on MPS.
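One thing worth ruling out when timing this by hand: MPS ops are dispatched asynchronously, and the first call can include one-time shader compilation, so a naive wall-clock measurement in a Jupyter cell can be misleading. A rough benchmark sketch with an explicit warm-up and torch.mps.synchronize() (available in recent PyTorch builds; interpolation settings are placeholders as above):

```python
import time
import torch
import torch.nn.functional as F

x = torch.randn(1, 25, 497, 497, device="mps")

# Warm-up run so one-time shader compilation isn't counted
F.interpolate(x, size=(512, 512), mode="bilinear", align_corners=False)
torch.mps.synchronize()

start = time.perf_counter()
y = F.interpolate(x, size=(512, 512), mode="bilinear", align_corners=False)
torch.mps.synchronize()  # wait for the GPU; MPS kernels run asynchronously
print(f"interpolate took {time.perf_counter() - start:.4f} s")
```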

Does anyone have any idea what’s causing this, or what steps I could take to narrow it down?
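One difference I can think of between the in-network tensor and a freshly created one is memory layout or dtype, which might push MPS onto a slower code path. Here is a hypothetical check (activations is a placeholder for whatever tensor the network actually feeds into interpolate):

```python
import torch
import torch.nn.functional as F

x = activations  # placeholder: the real tensor passed to interpolate in the pipeline
print(x.is_contiguous(), x.dtype, x.stride())

# Does forcing a contiguous copy change the timing?
y = F.interpolate(x.contiguous(), size=(512, 512),
                  mode="bilinear", align_corners=False)
```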

@iwasserman I was trying to run an existing model, but it was very slow on MPS, and I also found that torch.nn.functional.interpolate() was the culprit, similar to what you’re describing. The slowdown only appeared during model inference, not when running the interpolation by itself, same as for you.

Did you manage to find the issue?

No, sorry. I never solved this and have just accepted it as a limitation of the MPS implementation.