I have a tensor of shape `P.shape = [N, k]` and an index tensor of shape `ind.shape = [L, N]`, where `ind[i, j]` is a column index into `P[j]` (`ind[i, j] < k` always). I want to create a new tensor of shape `[L, N]` with `new[i, j] = P[j, ind[i, j]]`; the functionality can be produced with a for loop in the following manner:
```python
import torch

new = []
num_rows = P.shape[0]  # N, the first dimension of P
for experiment in range(ind.shape[0]):  # iterate over the L experiments
    # pick P[j, ind[experiment, j]] for every row j
    new.append(P[torch.arange(num_rows), ind[experiment]])
new = torch.stack(new)  # shape [L, N]
```
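For concreteness, a toy run (the sizes `N = 4, k = 3, L = 2` are made up purely for illustration):

```python
import torch

N, k, L = 4, 3, 2
P = torch.randn(N, k)
ind = torch.randint(k, (L, N))  # every entry is a valid column index (< k)

new = torch.stack([P[torch.arange(N), ind[e]] for e in range(L)])
print(new.shape)  # torch.Size([2, 4]), i.e. [L, N]
```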
But as `L` is really big, this loop is extremely slow.
Using `repeat` I managed to replicate the functionality:

```python
new = P.unsqueeze(0).repeat(L, 1, 1)              # [L, N, k]: L full copies of P
new = new.gather(2, ind.unsqueeze(2)).squeeze(2)  # [L, N]
```
But as `L` is really big, I get an OOM exception on the `.repeat(L, 1, 1)` line, since it materializes `L` full copies of `P`.
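A view-based variant with `expand` seems like it should sidestep the copy (a sketch I have not fully verified, assuming `gather` can read from the non-contiguous broadcasted view that `expand` returns):

```python
# expand allocates no memory for the L copies; it only creates a view
new = P.unsqueeze(0).expand(L, -1, -1)            # view of shape [L, N, k]
new = new.gather(2, ind.unsqueeze(2)).squeeze(2)  # [L, N]
```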
I was wondering if I can accomplish something similar using broadcasting?
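Concretely, something like the following is what I am hoping for (a sketch that relies on advanced indexing broadcasting the index tensors against each other, so that only the `[L, N]` result is ever allocated):

```python
# torch.arange(N) has shape [N] and broadcasts against ind of shape [L, N],
# giving new[i, j] = P[j, ind[i, j]]
new = P[torch.arange(P.shape[0]), ind]  # [L, N]
```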