I’m having a hard time adapting my code to work across a batch of examples. Previously I had
idx.shape == torch.Size([50, 4])
points.shape == torch.Size([400, 3])
# Index
points[idx].shape == torch.Size([50, 4, 3])
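For reference, a minimal self-contained setup with random data that reproduces these shapes (values are just placeholders):

import torch

points = torch.randn(400, 3)             # one observation: 400 candidate 3D points
idx = torch.randint(0, 400, (50, 4))     # 50 samples of 4 point indices each

print(points[idx].shape)                 # torch.Size([50, 4, 3])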
So I would sample four 3D points from a set of 400 points, and have this repeated 50 times. This was just for one observation. Now I would like to run this for each observation in a batch of 256, thus:
idx.shape == torch.Size([256, 50, 4])
points.shape == torch.Size([256, 400, 3])
# Doesn't work (indexes dim 0, the batch dim, rather than the point dim):
points[idx].shape
# Doesn't work (RuntimeError: index has size 4 at dim 2, but points only has 3):
torch.gather(points, 1, idx).shape
# Works but I want to avoid the loop
torch.stack([points[i][idx[i]] for i in range(len(points))]).shape == torch.Size([256, 50, 4, 3])
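For reference, a minimal self-contained repro of the batched case with random data (the names B, N, S, K are just mine for this snippet):

import torch

B, N, S, K = 256, 400, 50, 4
points = torch.randn(B, N, 3)                 # batch of 256 point sets
idx = torch.randint(0, N, (B, S, K))          # per observation: 50 samples of 4 indices

# gather along dim 1 fails: index has size 4 at dim 2, points only has 3
try:
    torch.gather(points, 1, idx)
except RuntimeError as e:
    print("gather failed:", e)

# the loop version gives the shape I want
out = torch.stack([points[i][idx[i]] for i in range(len(points))])
print(out.shape)                              # torch.Size([256, 50, 4, 3])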
Edit:
With a bit of messing around I have a solution but it doesn’t look pretty:
idx.shape == torch.Size([256, 50, 4])
points.shape == torch.Size([256, 400, 3])

m = points.unsqueeze(1).repeat(1, idx.size(1), 1, 1)       # copy points once per sample group
ids = idx.unsqueeze(-1).repeat(1, 1, 1, points.size(2))    # copy idx across the coordinate dim

print(m.shape)
# torch.Size([256, 50, 400, 3])
print(ids.shape)
# torch.Size([256, 50, 4, 3])
print(m.gather(2, ids).shape)
# torch.Size([256, 50, 4, 3])
It also seems to run slower than the loop:

%%timeit
A = torch.stack([points[i][idx[i]] for i in range(len(points))])
# 3.66 ms ± 6.41 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
m = points.unsqueeze(1).repeat(1, idx.size(1), 1, 1)
B = m.gather(2, idx.unsqueeze(-1).repeat(1, 1, 1, m.size(3)))
# 8.05 ms ± 10.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
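I suspect the slowdown comes from the two repeat calls, which materialise large intermediate copies (the [256, 50, 400, 3] tensor alone is about 15M float32 values, roughly 61 MB). A sketch of two alternatives I haven't timed: replacing repeat with expand (a view, no copy; this assumes gather is happy with non-contiguous inputs), and plain advanced indexing with a broadcast batch index:

# expand() returns views instead of copies; gather should still accept them
m = points.unsqueeze(1).expand(-1, idx.size(1), -1, -1)        # [256, 50, 400, 3], no copy
ids = idx.unsqueeze(-1).expand(-1, -1, -1, points.size(-1))    # [256, 50, 4, 3], no copy
C = m.gather(2, ids)                                           # torch.Size([256, 50, 4, 3])

# advanced indexing: a [256, 1, 1] batch index broadcasts against idx [256, 50, 4]
batch = torch.arange(points.size(0)).view(-1, 1, 1)
D = points[batch, idx]                                         # torch.Size([256, 50, 4, 3])

As far as I can tell, both should match the looped result element for element.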