Does anybody know why the interpolate() function in PyTorch (torch.nn.functional.interpolate) is so slow?
The forward pass of interpolate() is fast, but the backward pass is really slow.
If you only call interpolate() once or twice you may not notice the slowdown, but if you call it 10+ times you'll find the whole training time roughly doubles (e.g., when training on ImageNet).
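This is a minimal sketch for measuring the forward/backward gap in isolation. It assumes a recent PyTorch. The tensor shape, scale factor, mode and `time_interpolate` helper are illustrative choices, not taken from the training setup above. On GPU, kernels launch asynchronously, so the script calls `torch.cuda.synchronize()` around each timed region. Without that, forward work that is still pending can get charged to the backward timing.

```python
import time

import torch
import torch.nn.functional as F


def time_interpolate(n_calls=10, size=(8, 64, 56, 56), mode="bilinear",
                     scale=2.0, iters=5, device=None):
    """Hypothetical helper: average (forward_s, backward_s) for n_calls interpolate() calls."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    is_cuda = torch.device(device).type == "cuda"
    x = torch.randn(*size, device=device, requires_grad=True)
    # align_corners is only accepted by the (bi/tri)linear and bicubic modes.
    extra = {"align_corners": False} if mode in ("bilinear", "bicubic") else {}

    def sync():
        if is_cuda:
            torch.cuda.synchronize()

    def step():
        x.grad = None
        sync()
        t0 = time.perf_counter()
        # n_calls independent interpolate() calls on the same input, reduced and
        # summed so that one backward() runs n_calls interpolate backwards.
        # (The forward time also includes the cheap .sum() reductions.)
        loss = sum(F.interpolate(x, scale_factor=scale, mode=mode, **extra).sum()
                   for _ in range(n_calls))
        sync()
        t1 = time.perf_counter()
        loss.backward()
        sync()
        t2 = time.perf_counter()
        return t1 - t0, t2 - t1

    step()  # warm-up run (allocator, CUDA context), not timed
    results = [step() for _ in range(iters)]
    fwd = sum(r[0] for r in results) / iters
    bwd = sum(r[1] for r in results) / iters
    return fwd, bwd


if __name__ == "__main__":
    fwd, bwd = time_interpolate()
    print(f"forward: {fwd * 1e3:.2f} ms, backward: {bwd * 1e3:.2f} ms, "
          f"backward/forward: {bwd / fwd:.1f}x")
```

Comparing the backward/forward ratio across `mode` values (e.g. "nearest" vs "bilinear") and between CPU and GPU may help narrow down where the extra cost comes from.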