PyTorch 2-D and 3-D Table Interpolation

I am trying to get 2-D and 3-D interpolation table lookups running in PyTorch, but I don’t believe torch.lerp supports this, and I haven’t been able to find any other PyTorch-native solution. Has anyone built a neural network approximation of a 2-D or 3-D linear interpolation table in PyTorch? I am wondering whether that is a worthwhile endeavor for speed at inference time (especially if the network required to approximate a table is large), or whether there is a simpler, more standard way that I missed.

For some background: I am trying to encode a known physical relationship, which is given in the form of a linear interpolation table, into an LSTM time series forecasting system I am designing. The interpolation table queries would be done during training, so this isn’t simply a pre-processing step.
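
To make the operation concrete, here is a minimal sketch of the 2-D lookup I have in mind, written by hand with torch.searchsorted and plain tensor indexing. It is only an illustration under some assumptions: the table sits on a rectilinear grid with strictly increasing axes, queries outside the grid are clamped to the edge, and the names interp2d, x_axis, y_axis, table, xq and yq are hypothetical, not part of any PyTorch API.

```python
import torch


def interp2d(x_axis, y_axis, table, xq, yq):
    """Bilinear lookup in a table where table[i, j] = f(x_axis[i], y_axis[j]).

    Assumptions (hypothetical helper, not a PyTorch built-in):
      x_axis: (Nx,) and y_axis: (Ny,), both strictly increasing
      table:  (Nx, Ny)
      xq, yq: query tensors of identical shape
    Queries outside the grid are clamped to the nearest edge.
    """
    # Keep queries inside the table so the cell search below is always valid.
    xq = torch.clamp(xq, min=x_axis[0], max=x_axis[-1])
    yq = torch.clamp(yq, min=y_axis[0], max=y_axis[-1])

    # Index of the lower corner of the cell containing each query.
    # The indices are integers, so no gradient flows through this step.
    ix = torch.searchsorted(x_axis, xq.detach().contiguous()) - 1
    iy = torch.searchsorted(y_axis, yq.detach().contiguous()) - 1
    ix = ix.clamp(0, x_axis.numel() - 2)
    iy = iy.clamp(0, y_axis.numel() - 2)

    # Fractional position of each query inside its cell, in [0, 1].
    x0, x1 = x_axis[ix], x_axis[ix + 1]
    y0, y1 = y_axis[iy], y_axis[iy + 1]
    tx = (xq - x0) / (x1 - x0)
    ty = (yq - y0) / (y1 - y0)

    # Table values at the four corners of the cell.
    f00 = table[ix, iy]
    f01 = table[ix, iy + 1]
    f10 = table[ix + 1, iy]
    f11 = table[ix + 1, iy + 1]

    # Weighted sum of the corners; differentiable w.r.t. xq, yq and table.
    return ((1 - tx) * (1 - ty) * f00 + (1 - tx) * ty * f01
            + tx * (1 - ty) * f10 + tx * ty * f11)
```

Because the result is built from ordinary tensor arithmetic, autograd can propagate gradients to the query points (and to the table entries, if they require gradients), which is what I need when the lookup sits inside the training loop. I assume a 3-D version would follow the same pattern with a third axis and eight corner values instead of four, but I have not verified how this compares in speed to a small network approximating the table.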