How could I do this efficiently, without a for loop? I managed to do something similar with torch.kron and simpler tensors, but I struggle to do it with tensors of these shapes.
This is just what broadcasting is good for. (In the sketch below, None inserts a singleton dimension for broadcasting and : selects all of a given dimension. One could use unsqueeze or view instead, but I like this notation because it aligns things.)
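For concreteness, here is a minimal sketch. The shapes (a batch of matrices plus a per-channel interpolation weight) are made up for illustration and are not the ones from the question:

```python
import torch

# Hypothetical shapes, purely for illustration:
t1 = torch.randn(4, 2, 3, 3)   # (B, C, H, W)
t2 = torch.randn(4, 2, 3, 3)   # (B, C, H, W)
interp = torch.rand(2)         # (C,), interpolation weights in [0, 1]

# interp[None, :, None, None] has shape (1, C, 1, 1); the singleton
# dimensions broadcast against B, H and W of t1 and t2.
w = interp[None, :, None, None]
res = (1 - w) * t1 + w * t2    # (B, C, H, W)
```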
Alternatively, torch.lerp (sketched below) does the same interpolation and can be a bit faster (useful if your computation is bottlenecked on these operations), but it is a tad exotic, so chances are that people trying to understand the code will have to look up what it does.
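Using the same assumed shapes as above:

```python
# torch.lerp(start, end, weight) computes start + weight * (end - start);
# the weight tensor broadcasts just like in the manual formula above.
res_lerp = torch.lerp(t1, t2, interp[None, :, None, None])

# Both formulations agree up to floating-point error.
assert torch.allclose(res, res_lerp)
```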
One thing to keep in mind - and I got this wrong at first, too - is that when tensors have a differing number of dimensions, the smaller ones are padded with singleton dimensions in front. In particular, it is crucial to give interp the two trailing singleton dimensions explicitly, while the leading one could be left out (works, but is less clear IMHO), as the example below shows.
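A small illustration of that padding rule, again with the assumed shapes from above:

```python
# Broadcasting aligns shapes from the right, so a bare (C,) weight would be
# lined up against W rather than C and the expression would fail:
# (1 - interp) * t1 + interp * t2   # RuntimeError: shapes don't broadcast

# The trailing singleton dimensions must therefore be written out. The
# leading one may be omitted, because missing dimensions are padded in front:
res_a = (1 - interp[None, :, None, None]) * t1 + interp[None, :, None, None] * t2
res_b = (1 - interp[:, None, None]) * t1 + interp[:, None, None] * t2
assert torch.equal(res_a, res_b)   # same result; the first form is clearer
```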