Efficient tensor construction from other tensors


I am trying to perform a complex operation for which my implementation is very inefficient right now.

I have, on the one hand, a list of tuples, each tuple containing the same number of tensors of shape (batch, …), and on the other hand a tensor of shape (batch, 1) holding indexes into that list.

For instance, with a batch size of 4:

l = [(t11, t12),(t21, t22),(t31, t32)]  # list of tuples of tensors


idx = [0,2,1,1]   # index along batch dimension, pointing to positions in the list

I want to create a tuple of tensors (ti1, ti2), where the components of ti1 and ti2 along the batch dimension are picked from the corresponding list entries according to the index.

The solution I came up with is this one:

mod_tuple = tuple(
    torch.stack([l[i][itup][ibatch] for ibatch, i in enumerate(idx)])
    for itup in range(tuple_size)
)

However, with a batch size of 128 it takes a very long time.
Could you help me find an efficient implementation, please?

Thanks in advance!

Would it be possible to create a single tensor from l and index it directly, e.g. via gather, or are the t** tensors differently shaped?
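If the shapes do line up across the trajectory, a minimal sketch of that idea: stack each tuple slot along a new trajectory dimension and select with advanced indexing (the shapes 5 and 7 here are hypothetical placeholders).

```python
import torch

batch, traj_len = 4, 3

# hypothetical trajectory: 3 observations, each a tuple of two tensors
l = [(torch.randn(batch, 5), torch.randn(batch, 7)) for _ in range(traj_len)]
idx = torch.tensor([0, 2, 1, 1])  # one trajectory position per batch element

mod_tuple = tuple(
    # stack one tuple slot across the trajectory: (traj_len, batch, ...)
    # then pick, for each batch element b, the entry at position idx[b]
    torch.stack([obs[itup] for obs in l])[idx, torch.arange(batch)]
    for itup in range(2)
)
# mod_tuple[0][b] comes from l[idx[b]][0][b]
```

The pairwise advanced indexing `[idx, torch.arange(batch)]` replaces the Python-level loop with a single vectorized gather, so the cost no longer scales with the batch size in Python.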

Thanks for the answer. The t1* tensors may indeed have different shapes. Actually, (t11, t12) is an observation and the list is a trajectory of observations. What I am trying to do is build a new observation by picking a different point in the trajectory for each element of the batch, because only one observation is interesting in each trajectory. That way I can do a single forward pass through my model.

I guess it will be more efficient to do one forward pass per element of the trajectory and then pick the relevant results, instead of what I am trying to do here. I will try that.