I am trying to perform a fairly complex indexing operation, and my current implementation is very inefficient.
On one hand, I have a list of tuples, each containing the same number of tensors of shape (batch, …); on the other hand, I have a tensor of shape (batch, 1) holding indices into that list.
For instance, with a batch size of 4:
l = [(t11, t12),(t21, t22),(t31, t32)] # list of tuples of tensors
idx = [0,2,1,1] # index along batch dimension, pointing to positions in the list
I want to create a tuple of tensors (ti1, ti2) where, for each position b along the batch dimension, row b of ti1 is taken from row b of the first tensor in tuple l[idx[b]], and row b of ti2 from row b of the second tensor in that same tuple.
Here is the solution I came up with:
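A tiny hypothetical instance may make the selection rule concrete. The toy values are made up so that entry b of l[j][k] equals 100*j + 10*k + b, which means every selected value reveals the list entry, tuple position and batch row it came from; the expected outputs want_1 and want_2 (hypothetical names standing in for ti1 and ti2) are written out by hand.

```python
import torch

batch = 4
# Hypothetical toy data: 3 tuples of 2 tensors, each of shape (batch,).
# Entry b of l[j][k] equals 100*j + 10*k + b, so each value encodes its origin.
l = [tuple(torch.arange(batch) + 100 * j + 10 * k for k in range(2)) for j in range(3)]
idx = [0, 2, 1, 1]  # same index list as in the example above

# Desired outputs, written out by hand from the rule: out_k[b] = l[idx[b]][k][b]
want_1 = torch.tensor([0, 201, 102, 103])
want_2 = torch.tensor([10, 211, 112, 113])

# Check the rule row by row against the toy data
for b, j in enumerate(idx):
    assert torch.equal(l[j][0][b], want_1[b])
    assert torch.equal(l[j][1][b], want_2[b])
```

Note that the batch row b is preserved: only the list entry changes from row to row.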
tuple_size = len(l[0])  # number of tensors in each tuple
mod_tuple = tuple(torch.stack([l[i][itup][ibatch] for ibatch, i in enumerate(idx)]) for itup in range(tuple_size))
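One possible vectorized formulation is sketched below, under the assumption that all tensors at the same tuple position share one shape (batch, …) and live on the same device. It stacks component k of every tuple into a tensor of shape (len(l), batch, …), then uses PyTorch advanced indexing with the index tensor and torch.arange(batch) to pick stacked[idx[b], b] for every b at once. The helper name select_by_index is hypothetical, and the sketch has not been benchmarked against the loop above.

```python
import torch

def select_by_index(l, idx):
    """Hypothetical helper: returns a tuple whose k-th tensor has row b equal
    to l[idx[b]][k][b]. Assumes every l[j][k] has the same shape (batch, ...)
    for a given k, and that all tensors live on the same device."""
    device = l[0][0].device
    # Accept a list, a (batch,) tensor, or a (batch, 1) tensor of indices
    idx = torch.as_tensor(idx, device=device).reshape(-1).long()
    rows = torch.arange(idx.numel(), device=device)  # 0, 1, ..., batch-1
    out = []
    for k in range(len(l[0])):  # loop over tuple positions only, not the batch
        # Component k of every tuple, stacked: shape (len(l), batch, ...)
        stacked = torch.stack([tup[k] for tup in l])
        # Advanced indexing pairs idx[b] with rows[b]: result[b] = stacked[idx[b], b]
        out.append(stacked[idx, rows])
    return tuple(out)
```

The remaining Python loop runs tuple_size times instead of batch × tuple_size times. The trade-off is that torch.stack copies all len(l) × batch rows of each component before the selection.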
However, with a batch size of 128, it takes a very long time.
Could you please help me find an efficient implementation?
Thanks in advance!