gc625
December 10, 2022, 1:37am
#1
I have an idx tensor with shape torch.Size([2, 80000, 3]), which corresponds to batch, points, and the indices of 3 elements (from 0 to 512) in the feat tensor, which has shape torch.Size([2, 513, 2]).
I can't seem to find a way to use torch.gather to index feat with idx so that the resulting tensor has shape [2, 80000, 3, 2].
ptrblck
December 10, 2022, 4:27am
#2
Unfortunately, you haven't shared reference code (even a slow version) that I could use to compare my approach against, but this might work:
# idx holds, per batch, 80000 groups of 3 indices into dim 1 of x
idx = torch.randint(0, 513, [2, 80000, 3])
x = torch.randn(2, 513, 2)

# advanced indexing: the batch index of shape [2, 1, 1] broadcasts
# against idx of shape [2, 80000, 3]
out = x[torch.arange(x.size(0))[:, None, None], idx]
print(out.shape)
# torch.Size([2, 80000, 3, 2])
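For clarity, the broadcasting pairs each batch index with every entry of idx, so out[b, n, k] == x[b, idx[b, n, k]]. A rough sketch of the explicit equivalent, continuing from the snippet above (the batch_idx name is just for illustration):

# expand the batch indices to idx's full shape instead of relying on broadcasting
batch_idx = torch.arange(x.size(0))[:, None, None].expand_as(idx)  # [2, 80000, 3]
out_explicit = x[batch_idx, idx]  # same result, shape [2, 80000, 3, 2]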
gc625
December 10, 2022, 5:15am
#3
Ah, apologies. After a while I did manage to find a way to do it (I think), although it's a bit clunkier than your implementation:
idx = torch.randint(0, 513, [2, 80000, 3])
feats = torch.randn(2, 513, 2)

B, N, K = idx.shape
# flatten the index groups, then duplicate the index for both feature
# channels, since torch.gather expects the index to match the output shape
temp_idx = idx.reshape(B, N * K, 1)
gather_idx = torch.concat((temp_idx, temp_idx), dim=2)
gathered_features = torch.gather(feats, 1, gather_idx)
final_feats = gathered_features.reshape(B, N, K, -1)
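As an aside, a sketch (untested here) of how the duplicated concat could likely be replaced by an expand so the feature dimension isn't hard-coded, continuing from the snippet above:

C = feats.size(2)
# expand the flattened indices across all C feature channels instead of
# concatenating two copies; expand creates a broadcast view without copying
gather_idx = idx.reshape(B, N * K, 1).expand(-1, -1, C)
final_feats = torch.gather(feats, 1, gather_idx).reshape(B, N, K, C)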
Any ideas about which approach is better?
ptrblck
December 10, 2022, 5:20am
#4
Both methods should yield the same output, but you might want to profile both approaches.
On the CPU I’m getting a faster result using my direct indexing:
def fun1(x, idx):
    # direct advanced indexing with a broadcast batch index
    out = x[torch.arange(x.size(0))[:, None, None], idx]
    return out

def fun2(x, idx):
    # gather-based approach: flatten idx, duplicate it for both feature
    # channels, gather along dim 1, then reshape back
    B, N, K = idx.shape
    temp_idx = idx.reshape(B, N * K, 1)
    gather_idx = torch.concat((temp_idx, temp_idx), dim=2)
    gathered_features = torch.gather(x, 1, gather_idx)
    gathered_features = gathered_features.view(B, N, K, -1)
    return gathered_features
idx = torch.randint(0, 513, [2, 80000, 3])
x = torch.randn(2, 513, 2)
out1 = fun1(x, idx)
out2 = fun2(x, idx)
print((out1 - out2).abs().max())
# tensor(0.)
%timeit fun1(x, idx)
# 355 µs ± 97.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fun2(x, idx)
# 1.17 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Note that you would need to add synchronizations (e.g., torch.cuda.synchronize()) if you want to profile the code on the GPU, since CUDA kernels are launched asynchronously.
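A minimal sketch of how such a GPU timing could look, using CUDA events (the time_gpu helper is just an illustration, reusing fun1, fun2, x, and idx from above):

def time_gpu(fn, x, idx, iters=100):
    # warm up so CUDA init and allocator caching don't skew the numbers
    for _ in range(10):
        fn(x, idx)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x, idx)
    end.record()
    torch.cuda.synchronize()  # wait for all kernels before reading the timer
    return start.elapsed_time(end) / iters  # milliseconds per call

x_cuda, idx_cuda = x.cuda(), idx.cuda()
print(time_gpu(fun1, x_cuda, idx_cuda))
print(time_gpu(fun2, x_cuda, idx_cuda))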