You can do it with a single torch.gather call by repeating both tensors along the last dimension:
cols = torch.tensor([0, 4, 8, 9, 1, 5, 6, 10, 2, 3, 7, 11])  # reorders the columns to match the torch.cat version
y = torch.gather(mesh.repeat(1, 1, 4), dim=1, index=indices.repeat(1, 1, 3))[:, :, cols]
(The columns come out in a shuffled order, so to get the same result as yours I had to reorder them like that.)
It does make sense that they fall into four groups, one per index column:
- 0, 4, 8
- 1, 5, 9
- 2, 6, 10
- 3, 7, 11
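The order within each group comes from how repeat tiles the last dimension: column j of mesh.repeat(1, 1, 4) holds coordinate j % 3, while column j of indices.repeat(1, 1, 3) holds index column j % 4. So gathered column j is coordinate j % 3 of the point picked by index column j % 4, and within each group the columns are ordered by j % 3. For example, the group for index column 1 is ordered 9, 1, 5 because 9 % 3 = 0, 1 % 3 = 1 and 5 % 3 = 2.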
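Column j of mesh.repeat(1, 1, 4) holds coordinate j % 3 and column j of indices.repeat(1, 1, 3) holds index column j % 4, so cols does not have to be found by trial and error. The small plain-Python sketch below rebuilds it, assuming 4 index columns and 3 coordinates as in the question; source_column is a hypothetical helper name. It relies on 3 and 4 being coprime: with, say, 4 index columns and 2 coordinates, some (index, coordinate) pairs never line up after repeat, and the single-gather trick would not work as written.

```python
# Sketch: rebuild the cols permutation instead of hard-coding it.
# Assumption: 4 index columns and 3 coordinates, as in the question.
n_idx, n_coord = 4, 3

def source_column(i, c):
    """Column of the repeated gather that holds coordinate c of index column i."""
    # Exactly one j in range(12) satisfies both congruences because gcd(4, 3) == 1.
    return next(j for j in range(n_idx * n_coord) if j % n_idx == i and j % n_coord == c)

# Target layout (the torch.cat version): column 3 * i + c holds coordinate c of index column i.
cols = [source_column(i, c) for i in range(n_idx) for c in range(n_coord)]
print(cols)  # [0, 4, 8, 9, 1, 5, 6, 10, 2, 3, 7, 11]
```

torch.tensor(cols) can then stand in for the hard-coded tensor.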
If you do not care about the column order, you can drop the reordering. The values are the same, just in a different column order:
y = torch.gather(mesh.repeat(1, 1, 4), dim=1, index=indices.repeat(1, 1, 3))
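If you would rather avoid the permutation altogether, a different approach (not the repeat trick above) is to flatten the four index columns into the gather dimension, gather once, and reshape. Row n * 4 + i of the flattened result holds the point for index column i of query n, so the reshape puts the columns straight into the torch.cat order. This is a sketch: gather_neighbors is a hypothetical helper name, and the small shapes are only for a quick check.

```python
import torch

def gather_neighbors(mesh: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[b, n, D * i + c] = mesh[b, indices[b, n, i], c]."""
    B, N, K = indices.shape
    D = mesh.shape[-1]
    # (B, N, K) -> (B, N*K, 1) -> (B, N*K, D); expand returns a view, so no index copy.
    flat = indices.reshape(B, N * K, 1).expand(-1, -1, D)
    out = torch.gather(mesh, 1, flat)  # (B, N*K, D)
    # Row n*K + i holds slot i of query n, so this reshape gives the torch.cat layout.
    return out.reshape(B, N, K * D)

mesh = torch.rand(2, 10, 3)
indices = torch.randint(0, 10, (2, 10, 4))
# Reference result: the original per-column gather + torch.cat.
x = torch.cat(
    [torch.gather(mesh, 1, indices[:, :, i].unsqueeze(-1).expand(-1, -1, 3)) for i in range(4)], dim=2
)
z = gather_neighbors(mesh, indices)
print(z.shape, torch.equal(x, z))
```

Unlike the repeat version, this never materializes a 12-column copy of mesh before gathering.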
Test
Just to make sure that the result is the same.
import torch

mesh = torch.rand(32, 4096, 3)
indices = torch.randint(0, 4096, (32, 4096, 4))  # high is exclusive, so 4096 covers every valid row
x = torch.cat(
[torch.gather(mesh, 1, index=indices[:, :, i].unsqueeze(-1).expand(-1, -1, 3)) for i in range(4)], dim=2
)
cols = torch.tensor([0, 4, 8, 9, 1, 5, 6, 10, 2, 3, 7, 11])
y = torch.gather(mesh.repeat(1, 1, 4), dim=1, index=indices.repeat(1, 1, 3))[:, :, cols]
print(x.shape)
print(y.shape)
print(torch.all(x==y))
Result
torch.Size([32, 4096, 12])
torch.Size([32, 4096, 12])
tensor(True)