Unable to use `torch.gather` with 3D index and 3D input

I have an `idx` tensor with shape `torch.Size([2, 80000, 3])`, whose dimensions correspond to batch, points, and 3 indices per point (values from 0 to 512) into the `feat` tensor with shape `torch.Size([2, 513, 2])`.

I can't seem to find a way to use `torch.gather` to index `feat` with `idx` so that the result has shape `[2, 80000, 3, 2]`.

Unfortunately, you haven't shared a reference implementation (even a slow one) that I could compare my approach against, but this might work:

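```python
import torch

idx = torch.randint(0, 513, [2, 80000, 3])
x = torch.randn(2, 513, 2)

# torch.arange(B)[:, None, None] has shape [B, 1, 1] and broadcasts against idx
out = x[torch.arange(x.size(0))[:, None, None], idx]
print(out.shape)
# torch.Size([2, 80000, 3, 2])
```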
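To make the intended semantics explicit, here is a slow loop-based reference, `out[b, n, k, :] = x[b, idx[b, n, k], :]`, checked against the indexing approach on a small input. `reference_gather` is a hypothetical helper, not a PyTorch API. The indexing works because the two index tensors broadcast: the batch index has shape `[B, 1, 1]`, `idx` has shape `[B, N, K]`, so together they select a `[B, N, K]` grid of rows, and the untouched last dimension of `x` is appended.

```python
import torch

def reference_gather(x, idx):
    # Hypothetical slow reference: out[b, n, k, :] = x[b, idx[b, n, k], :]
    B, N, K = idx.shape
    out = x.new_empty((B, N, K, x.size(-1)))
    for b in range(B):
        for n in range(N):
            for k in range(K):
                out[b, n, k] = x[b, idx[b, n, k].item()]
    return out

# Small sizes so the Python loops finish quickly
idx = torch.randint(0, 513, (2, 50, 3))
x = torch.randn(2, 513, 2)

# Batch index of shape [2, 1, 1] broadcasts with idx of shape [2, 50, 3];
# the trailing feature dim of x is kept, giving [2, 50, 3, 2]
batch = torch.arange(x.size(0))[:, None, None]
out = x[batch, idx]
print(out.shape)  # torch.Size([2, 50, 3, 2])
print(torch.equal(out, reference_gather(x, idx)))  # True
```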

Ah, apologies. After a while I did manage to find a way to do it (I think), although it's a bit clunkier than your implementation:

```python
import torch

idx = torch.randint(0, 513, [2, 80000, 3])
feats = torch.randn(2, 513, 2)

B, N, K = idx.shape
temp_idx = idx.reshape(B, N * K, 1)
# Duplicate the index for both feature channels (assumes 2 channels)
gather_idx = torch.concat((temp_idx, temp_idx), dim=2)
gathered_features = torch.gather(feats, 1, gather_idx)
final_feats = gathered_features.reshape(B, N, K, 2)
```
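The `torch.concat` step above copies the index once per feature channel, which hard-codes two channels. A sketch of a channel-agnostic variant, assuming `feats` is `[B, P, C]` and `idx` is `[B, N, K]`, uses `expand` instead: it returns a view, so no index copy is made. `gather_points` is a hypothetical name used only for illustration.

```python
import torch

def gather_points(feats, idx):
    # Hypothetical helper. feats: [B, P, C], idx: [B, N, K] with values in [0, P)
    B, N, K = idx.shape
    C = feats.size(-1)
    # Flatten points and neighbours, then expand the index over the channel dim;
    # expand() creates a view instead of concatenating copies
    gather_idx = idx.reshape(B, N * K, 1).expand(-1, -1, C)
    out = torch.gather(feats, 1, gather_idx)  # [B, N*K, C]
    return out.reshape(B, N, K, C)

idx = torch.randint(0, 513, (2, 80000, 3))
feats = torch.randn(2, 513, 5)  # any channel count works
out = gather_points(feats, idx)
print(out.shape)  # torch.Size([2, 80000, 3, 5])
# Agrees with the direct-indexing version
print(torch.equal(out, feats[torch.arange(2)[:, None, None], idx]))  # True
```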

Any ideas about which approach is better?

Both methods should yield the same output, but you might want to profile both approaches.
On the CPU, direct indexing is faster for me:

```python
import torch

def fun1(x, idx):
    out = x[torch.arange(x.size(0))[:, None, None], idx]
    return out

def fun2(x, idx):
    B, N, K = idx.shape
    temp_idx = idx.reshape(B, N * K, 1)
    gather_idx = torch.concat((temp_idx, temp_idx), dim=2)
    gathered_features = torch.gather(x, 1, gather_idx)
    gathered_features = gathered_features.view(B, N, K, 2)
    return gathered_features


idx = torch.randint(0, 513, [2, 80000, 3])
x = torch.randn(2, 513, 2)

out1 = fun1(x, idx)
out2 = fun2(x, idx)

print((out1 - out2).abs().max())
# tensor(0.)

# %timeit is an IPython/Jupyter magic
%timeit fun1(x, idx)
# 355 µs ± 97.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fun2(x, idx)
# 1.17 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

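Note that CUDA operations run asynchronously, so if you want to profile the code on the GPU you would need to add synchronizations (e.g. `torch.cuda.synchronize()`) before starting and before stopping the timer.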
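A minimal timing sketch along those lines, assuming a plain `time.perf_counter` loop; `time_fn` is a hypothetical helper. It synchronizes only when an argument lives on a CUDA device and falls back to the CPU otherwise. `torch.utils.benchmark.Timer` is a built-in alternative that takes care of CUDA synchronization for you.

```python
import time
import torch

def time_fn(fn, *args, iters=100, warmup=10):
    # Hypothetical helper: average wall-clock time per call in milliseconds.
    # CUDA kernels launch asynchronously, so synchronize before reading the clock;
    # otherwise only the launch overhead would be measured.
    on_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    for _ in range(warmup):
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3

def fun1(x, idx):
    # Create the batch index on the same device as x
    return x[torch.arange(x.size(0), device=x.device)[:, None, None], idx]

device = "cuda" if torch.cuda.is_available() else "cpu"
idx = torch.randint(0, 513, (2, 80000, 3), device=device)
x = torch.randn(2, 513, 2, device=device)
print(f"fun1 on {device}: {time_fn(fun1, x, idx):.3f} ms per call")
```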
