A few questions:
Is the interpolation being done on the GPU or CPU?
If the model runs on the GPU, have you tried scripting it with torch.jit.script (see the PyTorch 1.13 documentation)? TorchScript could potentially fuse the pointwise operations done before the interpolation.
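A minimal sketch of what scripting could look like, assuming the model does a few pointwise ops and then calls F.interpolate. The module name UpsampleHead and its specific ops are hypothetical stand-ins for your own model. Fusion is not guaranteed: the TorchScript profiling executor typically needs a few warm-up calls before it specializes the graph, and fusion mainly applies on CUDA tensors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UpsampleHead(nn.Module):
    # Hypothetical module: a chain of pointwise ops followed by an interpolation.
    def __init__(self, scale: float = 2.0):
        super().__init__()
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * 0.5 + 1.0   # pointwise: multiply + add
        x = torch.relu(x)   # pointwise: activation
        return F.interpolate(x, scale_factor=self.scale, mode="bilinear", align_corners=False)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = UpsampleHead().to(device).eval()
scripted = torch.jit.script(model)  # compile the module to TorchScript

x = torch.rand(1, 3, 8, 8, device=device)
with torch.no_grad():
    # A few warm-up runs let the profiling executor specialize (and, on GPU, possibly fuse).
    for _ in range(3):
        out_scripted = scripted(x)
    out_eager = model(x)

print(out_scripted.shape)
```

When benchmarking, time the scripted module only after the warm-up calls; the first runs include profiling and optimization overhead.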
If you can tolerate a bleeding-edge user experience, I would also check whether the compile function in PyTorch 2.0 (torch.compile) can fuse some of the pointwise ops and offer a speedup.