# Multidim Indexing

I’m having a hard time adapting my code to work on a batch. Previously I had:

```python
idx.shape == torch.Size([50, 4])
points.shape == torch.Size([400, 3])
# Index
points[idx].shape == torch.Size([50, 4, 3])
```
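As a quick check of what this single-observation indexing does: indexing with an integer tensor (advanced indexing) replaces the indexed dimension with the index's shape, so each entry of `idx` selects one whole 3D row of `points`. The sketch below uses random data as a stand-in for the real tensors.

```python
import torch

torch.manual_seed(0)
points = torch.randn(400, 3)            # stand-in: 400 points in 3D
idx = torch.randint(0, 400, (50, 4))    # stand-in: 50 sets of 4 point indices

out = points[idx]                       # shape [50, 4, 3]

# Each output entry is the full 3D point picked by the matching index
i, j = 7, 2
assert torch.equal(out[i, j], points[idx[i, j]])
print(out.shape)
```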

So I sample four 3D points from a set of 400 points, and repeat this 50 times. That was for a single observation. Now I would like to do the same for each observation in a batch of 256:

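```python
idx.shape == torch.Size([256, 50, 4])
points.shape == torch.Size([256, 400, 3])

# Doesn't work: idx indexes dim 0 (the 256 batch entries), not the 400 points
points[idx].shape
# Doesn't work: gather's index can't be larger than points in any dim other
# than dim 1, and idx's last dim (4) exceeds points' last dim (3)
torch.gather(points, 1, idx).shape

# Works, but I want to avoid the loop
torch.stack([points[i][idx[i]] for i in range(len(points))]).shape == torch.Size([256, 50, 4, 3])
```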
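One common way to drop the loop (a suggestion, not from the original post; `B`, `S`, `K`, `N` are just shorthand for the shapes above, and the data is random) is to index the batch and point dimensions together. A batch index of shape `[B, 1, 1]` broadcasts against `idx`, so element `(b, s, k)` reads `points[b, idx[b, s, k], :]`.

```python
import torch

B, S, K, N = 256, 50, 4, 400   # batch, sample sets, points per set, cloud size
points = torch.randn(B, N, 3)
idx = torch.randint(0, N, (B, S, K))

# [B, 1, 1] broadcasts with idx's [B, S, K]; the trailing coordinate dim is kept
batch = torch.arange(B).view(B, 1, 1)
out = points[batch, idx]                # shape [B, S, K, 3]

# Same result as the loop from the question
ref = torch.stack([points[i][idx[i]] for i in range(B)])
assert torch.equal(out, ref)
print(out.shape)
```

Both index tensors are broadcast to a common shape before lookup, so no copy of `points` is made per sample set.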

Edit:

With a bit of messing around I found a solution, but it doesn’t look pretty:

```python
idx.shape == torch.Size([256, 50, 4])
points.shape == torch.Size([256, 400, 3])

# Copy the whole point cloud once per sample set
m = points.unsqueeze(1).repeat(1, idx.size(1), 1, 1)
# Copy each index across the 3 coordinates
ids = idx.unsqueeze(-1).repeat(1, 1, 1, m.size(3))

# torch.Size([256, 50, 400, 3])
print(m.shape)

# torch.Size([256, 50, 4, 3])
print(ids.shape)

# torch.Size([256, 50, 4, 3])
print(m.gather(2, ids).shape)
```

It also seems to run slower than the loop:

```python
%%timeit
A = torch.stack([pt_coords[i][idc[i]] for i in range(len(pt_coords))])
# 3.66 ms ± 6.41 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

```python
%%timeit
coords_rep = pt_coords.unsqueeze(1).repeat(1, idc.size(1), 1, 1)
B = coords_rep.gather(2, idc.unsqueeze(-1).repeat(1, 1, 1, coords_rep.size(3)))
# 8.05 ms ± 10.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
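The repeat version materializes a `[256, 50, 400, 3]` copy of the points, which likely accounts for much of the slowdown. Below is a sketch (a suggestion, not the original code) of a gather that avoids that copy by flattening `idx` and using `expand`, which returns a view instead of allocating, plus a small harness that checks every variant agrees before timing it. The helper names are hypothetical; timing uses `time.perf_counter` on CPU tensors (on a GPU you would call `torch.cuda.synchronize()` before reading the clock), and no numbers are claimed because they depend on hardware and PyTorch version.

```python
import time
import torch

def loop_index(points, idx):
    # The loop from the question: one advanced index per batch entry
    return torch.stack([points[i][idx[i]] for i in range(len(points))])

def repeat_gather(points, idx):
    # The repeat-based version from the edit: copies the cloud S times
    rep = points.unsqueeze(1).repeat(1, idx.size(1), 1, 1)
    return rep.gather(2, idx.unsqueeze(-1).repeat(1, 1, 1, rep.size(3)))

def expand_gather(points, idx):
    # Flatten the S*K indices per batch entry and expand them (a view, no copy)
    # across the coordinates: [B, S*K, C]
    B, S, K = idx.shape
    C = points.size(-1)
    flat = idx.reshape(B, S * K, 1).expand(-1, -1, C)
    # out[b, m, c] = points[b, flat[b, m, c], c]
    return points.gather(1, flat).reshape(B, S, K, C)

def bench(fn, *args, reps=20):
    fn(*args)  # warm-up call, not timed
    start = time.perf_counter()
    for _ in range(reps):
        fn(*args)
    return (time.perf_counter() - start) / reps * 1e3  # ms per call

points = torch.randn(256, 400, 3)
idx = torch.randint(0, 400, (256, 50, 4))

ref = loop_index(points, idx)
for fn in (loop_index, repeat_gather, expand_gather):
    assert torch.equal(fn(points, idx), ref)
    print(f"{fn.__name__}: {bench(fn, points, idx):.2f} ms per call")
```

`gather` only requires the index to be no larger than the input outside the gathered dimension, so the expanded `[B, S*K, 3]` index works directly on the original `[B, 400, 3]` points.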