Compute interpolation of two tensors

Let’s say I have:

  • two tensors, T0 with size B x a x b and T1 with size B x a x b
  • a list interps of interpolations: [0.25, 0.5, 0.75]

and I want to fill a new tensor T3 with size B x (len(interps) + 2) x a x b so that:

T3[:, 0, ...] = T0
T3[:, 1, ...] = T0 * (1-interps[0]) + T1 * (interps[0])
T3[:, 2, ...] = T0 * (1-interps[1]) + T1 * (interps[1])
T3[:, 3, ...] = T0 * (1-interps[2]) + T1 * (interps[2])
T3[:, 4, ...] = T1
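For reference, here is the specification above written as the explicit loop the question wants to avoid. It can serve as a baseline for checking a vectorized version. This is a minimal sketch: the function name interpolate_loop and the example sizes are hypothetical, and it assumes interps is a plain Python list of floats.

```python
import torch

def interpolate_loop(T0, T1, interps):
    """Hypothetical reference helper using an explicit Python loop.

    T0, T1: tensors of shape (B, a, b); interps: list of floats.
    Returns a tensor of shape (B, len(interps) + 2, a, b).
    """
    # Slot 0 holds T0, the last slot holds T1, the middle slots hold the blends.
    slices = [T0]
    for w in interps:
        slices.append(T0 * (1 - w) + T1 * w)
    slices.append(T1)
    # Stack along a new dimension 1 to get B x (len(interps) + 2) x a x b.
    return torch.stack(slices, dim=1)

T0 = torch.randn(3, 4, 4)  # example sizes: B = 3, a = b = 4
T1 = torch.randn(3, 4, 4)
T3 = interpolate_loop(T0, T1, [0.25, 0.5, 0.75])
print(T3.shape)  # torch.Size([3, 5, 4, 4])
```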

How could I do this efficiently, without a for loop? I managed to do something similar with torch.kron on simpler tensors, but I struggle to do it with tensors of these shapes.

Thanks for the help

This is exactly what broadcasting is good for. Below, None inserts a singleton dimension for broadcasting, and : keeps all of a given dimension. One could use unsqueeze or view instead, but I like this notation because it lines the dimensions up visually.
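A quick sketch, with arbitrarily chosen example sizes, showing that None indexing, unsqueeze and view give the same result, and how the resulting shapes broadcast against each other:

```python
import torch

B, a, b = 3, 4, 4  # example sizes, chosen arbitrarily
T0 = torch.randn(B, a, b)
interps = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])

# Indexing with None inserts a size-1 dimension at that position.
t_none = T0[:, None, :, :]
t_unsq = T0.unsqueeze(1)      # same result via unsqueeze
t_view = T0.view(B, 1, a, b)  # same result via view
print(t_none.shape)  # torch.Size([3, 1, 4, 4])

# interps becomes a (1, K, 1, 1) tensor, so it varies only along dimension 1.
w = interps[None, :, None, None]
print(w.shape)  # torch.Size([1, 5, 1, 1])

# Broadcasting (3, 1, 4, 4) against (1, 5, 1, 1) yields (3, 5, 4, 4).
print((t_none * w).shape)  # torch.Size([3, 5, 4, 4])
```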

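import torch

T0, T1 = torch.randn(2, 3, 4, 4)  # unpacks into two tensors of shape (3, 4, 4)

interps = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])  # 0.0 and 1.0 give T0 and T1 at the ends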

T3 = T0[:, None, :, :] * (1 - interps[None, :, None, None]) + T1[:, None, :, :] * interps[None, :, None, None]

or equivalently

T3 = torch.lerp(T0[:, None, :, :], T1[:, None, :, :], interps[None, :, None, None])

torch.lerp(start, end, weight) computes start + weight * (end - start), so it is the same interpolation in a single call. It can be a bit faster (useful if your computation is bottlenecked on these operations), but torch.lerp is a tad exotic, so people trying to understand the code will likely have to look up what it does.
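Note that interps in this answer already contains the endpoints 0.0 and 1.0, which reproduce T0 and T1 in the first and last slots. A minimal sketch that starts from the question's list [0.25, 0.5, 0.75], adds those endpoints, and uses torch.lerp; the helper name interpolate is hypothetical, and matching the weights' dtype and device to T0 is an assumption about what you would want:

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible

def interpolate(T0, T1, interps):
    """Hypothetical helper: returns a tensor of shape (B, len(interps) + 2, a, b).

    Weights 0.0 and 1.0 are added at both ends so that slot 0 equals T0 and the
    last slot equals T1, as in the question's specification.
    """
    w = torch.tensor([0.0, *interps, 1.0], dtype=T0.dtype, device=T0.device)
    w = w[None, :, None, None]  # shape (1, K, 1, 1) with K = len(interps) + 2
    # torch.lerp(start, end, weight) = start + weight * (end - start)
    return torch.lerp(T0[:, None], T1[:, None], w)

T0, T1 = torch.randn(2, 3, 4, 4)  # each of shape (B, a, b) = (3, 4, 4)
T3 = interpolate(T0, T1, [0.25, 0.5, 0.75])

# Cross-check against the explicit broadcasting formula.
w = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])[None, :, None, None]
T3_ref = T0[:, None] * (1 - w) + T1[:, None] * w
print(T3.shape)  # torch.Size([3, 5, 4, 4])
# atol is loosened slightly because the two formulas round differently.
print(torch.allclose(T3, T3_ref, atol=1e-5))  # True
```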

Best regards

Thomas

Nice! I tried to do something similar but couldn’t make it work.

Thanks a lot!

One thing to keep in mind (I got this wrong at first, too): when the tensors have different numbers of dimensions, broadcasting pads the one with fewer dimensions with singleton dimensions in front, i.e. dimensions are aligned from the right. In particular, it is crucial that you give interps the two trailing singleton dimensions, while you could leave out the leading one (that works, but is less clear IMHO).
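A small sketch of that pitfall, with example sizes chosen so that len(interps) = 5 differs from b = 4 (the variable names are illustrative only):

```python
import torch

B, a, b = 3, 4, 4  # example sizes; K = len(interps) = 5 deliberately differs from b
T0 = torch.randn(B, a, b)
T1 = torch.randn(B, a, b)
interps = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])

# Leaving out the leading None is fine: (5, 1, 1) is padded on the left to (1, 5, 1, 1).
w = interps[:, None, None]
T3 = T0[:, None] * (1 - w) + T1[:, None] * w
print(T3.shape)  # torch.Size([3, 5, 4, 4])

# Leaving out the trailing dims is not: (5,) is padded to (1, 1, 1, 5), so it is
# aligned with the last dimension (size b = 4) and broadcasting fails.
failed = False
try:
    T0[:, None] * (1 - interps) + T1[:, None] * interps
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# Worse, if b happened to equal K there would be no error, just a wrong shape:
T0_b5 = torch.randn(B, a, 5)
wrong = T0_b5[:, None] * interps  # (3, 1, 4, 5) * (5,) -> (3, 1, 4, 5)
print(wrong.shape)  # torch.Size([3, 1, 4, 5])
```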
