Selecting and concatenating rows

Hi! I have two tensors:

  • values, FloatTensor, shape: [32, 4096, 3], value ∈ ℝ
  • indices, LongTensor, shape: [32, 4096, 4], index ∈ {0…4095}

The indices point into dim 1 of values. How can I create a new tensor of shape [32, 4096, 12] by selecting the four triplets that each row of indices points to and concatenating them?
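To pin down what the output should contain, here is a slow loop-based reference. It assumes the layout out[b, n, 3*k + j] = values[b, indices[b, n, k], j], so the four triplets sit side by side in the order given by indices. That layout is inferred from the cat-based code in this thread. The function name select_triplets_reference is hypothetical, and the sizes are kept small so the Python loops finish quickly.

```python
import torch


def select_triplets_reference(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Slow reference: out[b, n, C*k + j] = values[b, indices[b, n, k], j]."""
    B, N, K = indices.shape      # e.g. 32, 4096, 4
    C = values.size(2)           # 3 values per triplet
    out = values.new_empty(B, N, K * C)
    for b in range(B):
        for n in range(N):
            for k in range(K):
                # copy the whole triplet that indices[b, n, k] points to
                out[b, n, k * C:(k + 1) * C] = values[b, int(indices[b, n, k])]
    return out


# Small sizes so the loops stay fast
values = torch.rand(2, 5, 3)
indices = torch.randint(0, 5, (2, 5, 4))
out = select_triplets_reference(values, indices)
print(out.shape)  # torch.Size([2, 5, 12])
```

Any vectorized version can be checked against this on small random inputs.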

Edit:
I came up with this, but it doesn’t seem very efficient:

# mesh is the values tensor [32, 4096, 3]; one gather per index column (4 in total)
x = torch.cat(
    [torch.gather(mesh, 1, index=indices[:, :, i].unsqueeze(-1).expand(-1, -1, 3)) for i in range(4)], dim=2
)
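The Python loop over the index columns can likely be folded into a single gather by giving the four indices their own dimension. This is a minimal sketch, assuming mesh is the values tensor from the question. The function name select_triplets_gather is hypothetical. The expand calls create views without copying, and the single gather returns a [B, N, 4, 3] block that is flattened into the 12 output columns.

```python
import torch


def select_triplets_gather(mesh: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    B, M, C = mesh.shape       # [32, 4096, 3]
    _, N, K = indices.shape    # [32, 4096, 4]
    # View mesh as [B, M, K, C] without copying (stride 0 along the new dim)
    src = mesh.unsqueeze(2).expand(B, M, K, C)
    # Repeat each index over the C values of its triplet: [B, N, K, C]
    idx = indices.unsqueeze(-1).expand(B, N, K, C)
    # out[b, n, k, j] = src[b, idx[b, n, k, j], k, j] = mesh[b, indices[b, n, k], j]
    out = torch.gather(src, 1, idx)
    # Flatten (K, C) so triplet k lands in columns C*k .. C*k + C - 1
    return out.reshape(B, N, K * C)


mesh = torch.rand(32, 4096, 3)
indices = torch.randint(0, 4096, (32, 4096, 4))
print(select_triplets_gather(mesh, indices).shape)  # torch.Size([32, 4096, 12])
```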

You can do something like this.

cols = torch.tensor([0, 4, 8, 9, 1, 5, 6, 10, 2, 3, 7, 11])  # reorders the 12 gathered columns
y = torch.gather(mesh.repeat(1, 1, 4), dim=1, index=indices.repeat(1, 1, 3))[:, :, cols]

(I don’t know why it needs this odd ordering, but I had to shuffle the columns like that to get the same result as yours.)
It does make sense that they are grouped into:

  • 0, 4, 8
  • 1, 5, 9
  • 2, 6, 10
  • 3, 7, 11

But why the columns within each group come out permuted, I cannot answer.
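One way to see where the ordering comes from: Tensor.repeat tiles, it does not interleave. Column c of mesh.repeat(1, 1, 4) holds value c % 3 of a triplet, and column c of indices.repeat(1, 1, 3) holds index column c % 4. The gathered column c is therefore value c % 3 of triplet c % 4. The target column 3*k + j needs c % 4 == k and c % 3 == j. Because 3 and 4 are coprime, exactly one c in 0..11 satisfies both conditions (Chinese remainder theorem). The groups 0, 4, 8 / 1, 5, 9 / ... are the columns that share the same c % 4, that is, the same triplet.

The sketch below derives cols this way. It also shows a variant, swapped in here and not taken from the replies above, that uses repeat_interleave on indices so no reordering is needed. The function name select_triplets_interleave is hypothetical.

```python
import torch

K, C = 4, 3  # 4 indices per point, 3 values per triplet

# Target column C*k + j needs the unique tiled column c with c % K == k and c % C == j
cols = torch.tensor([
    next(c for c in range(K * C) if c % K == k and c % C == j)
    for k in range(K)
    for j in range(C)
])
print(cols.tolist())  # [0, 4, 8, 9, 1, 5, 6, 10, 2, 3, 7, 11]


def select_triplets_interleave(mesh: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    n_idx, n_val = indices.size(2), mesh.size(2)
    # repeat_interleave gives index columns 0,0,0,1,1,1,... (column c uses index c // n_val),
    # while mesh.repeat tiles value columns 0,1,2,0,1,2,... (column c uses value c % n_val),
    # so the gathered columns already come out in n_val*k + j order.
    idx = indices.repeat_interleave(n_val, dim=2)            # [B, N, n_idx * n_val]
    return torch.gather(mesh.repeat(1, 1, n_idx), 1, idx)    # [B, N, n_idx * n_val]
```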

If you don’t care about the column order, you can skip the reordering with cols. The values are the same, just in a different column order:

y = torch.gather(mesh.repeat(1, 1, 4), dim=1, index=indices.repeat(1, 1, 3))

Test

Just to make sure that the result is the same.

import torch

mesh = torch.rand(32, 4096, 3)
indices = torch.randint(0, 4096, (32, 4096, 4))  # high is exclusive, so this covers 0..4095

x = torch.cat(
    [torch.gather(mesh, 1, index=indices[:, :, i].unsqueeze(-1).expand(-1, -1, 3)) for i in range(4)], dim=2
)

cols = torch.tensor([0, 4, 8, 9, 1, 5, 6, 10, 2, 3, 7, 11])
y = torch.gather(mesh.repeat(1, 1, 4), dim=1, index=indices.repeat(1, 1, 3))[:, :, cols]

print(x.shape)
print(y.shape)
print(torch.all(x == y))

Result

torch.Size([32, 4096, 12])
torch.Size([32, 4096, 12])
tensor(True)

Thank you @Matias_Vasquez!

PyTorch indexing (and NumPy indexing, if it works the same way) is a complete mystery to me. I also found this solution:

import torch

mesh = torch.rand(32, 4096, 3)
indices = torch.randint(0, 4096, (32, 4096, 4))  # high is exclusive, so this covers 0..4095

x = torch.cat([torch.gather(mesh, 1, index=indices[:, :, i].unsqueeze(-1).expand(-1, -1, 3)) for i in range(4)], dim=2)

cols = torch.tensor([0, 4, 8, 9, 1, 5, 6, 10, 2, 3, 7, 11])
y = torch.gather(mesh.repeat(1, 1, 4), dim=1, index=indices.repeat(1, 1, 3))[:, :, cols]

# Another solution: advanced indexing
i = indices.view(indices.size(0), indices.size(1) * indices.size(2))  # [32, 4096 * 4]
z = mesh[torch.arange(mesh.size(0)).unsqueeze(-1), i].view(indices.size(0), indices.size(1), -1)  # [32, 4096, 12]

print(x.shape)
print(y.shape)
print(z.shape)
print(torch.all(x == y))
print(torch.all(x == z))

Results:

torch.Size([32, 4096, 12])
torch.Size([32, 4096, 12])
torch.Size([32, 4096, 12])
tensor(True)
tensor(True)
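Here is why the advanced-indexing version (z) works. The batch index torch.arange(B).unsqueeze(-1) has shape [B, 1] and the flattened indices have shape [B, N*K]. Both broadcast to [B, N*K], and each pair (b, m) selects the full row mesh[b, flat[b, m]]. The last dimension (3) is not indexed, so it is kept. Because entry n*K + k of the flattened indices is indices[b, n, k], reshaping the [B, N*K, 3] result to [B, N, K*3] places triplet k in columns 3k..3k+2. The small example below uses readable integer values and shows the intermediate shapes. It uses reshape instead of view: view only works here because indices is contiguous, while reshape also handles non-contiguous inputs.

```python
import torch

mesh = torch.arange(2 * 5 * 3).reshape(2, 5, 3)  # small, easy-to-read values
indices = torch.randint(0, 5, (2, 5, 4))

B, N, K = indices.shape
batch = torch.arange(B).unsqueeze(-1)       # [B, 1], broadcasts against [B, N*K]
flat = indices.reshape(B, N * K)            # [B, N*K]; entry n*K + k is indices[b, n, k]
rows = mesh[batch, flat]                    # [B, N*K, 3]; rows[b, m] = mesh[b, flat[b, m]]
z = rows.reshape(B, N, K * mesh.size(2))    # [B, N, 12]; triplet k in columns 3k..3k+2
print(rows.shape, z.shape)  # torch.Size([2, 20, 3]) torch.Size([2, 5, 12])
```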

But I’m not sure which one is better; I don’t understand any of them.
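Which one is faster is best answered by measuring on your own hardware. A few facts help explain the differences. mesh.repeat(1, 1, 4) materializes a copy of mesh four times its size. The cat version runs one gather per index column and then a concatenation. The indexing version reads each selected triplet directly, with no enlarged copy of mesh.

The sketch below times three variants with torch.utils.benchmark and checks that they agree. Its assumptions: the function names are hypothetical, via_repeat uses repeat_interleave instead of the cols permutation (a swapped-in variant), and the default sizes match the question. No results are claimed here.

```python
import torch
import torch.utils.benchmark as benchmark


def via_cat(mesh, indices):
    C = mesh.size(2)
    # one gather per index column, then concatenate along the last dim
    return torch.cat([torch.gather(mesh, 1, indices[:, :, k].unsqueeze(-1).expand(-1, -1, C))
                      for k in range(indices.size(2))], dim=2)


def via_repeat(mesh, indices):
    C, K = mesh.size(2), indices.size(2)
    # tiled values + interleaved indices -> columns already in C*k + j order
    return torch.gather(mesh.repeat(1, 1, K), 1, indices.repeat_interleave(C, dim=2))


def via_indexing(mesh, indices):
    B, N, K = indices.shape
    batch = torch.arange(B, device=mesh.device).unsqueeze(-1)
    return mesh[batch, indices.reshape(B, N * K)].reshape(B, N, K * mesh.size(2))


def compare(B=32, N=4096, K=4, C=3, device="cpu", min_run_time=0.2):
    mesh = torch.rand(B, N, C, device=device)
    indices = torch.randint(0, N, (B, N, K), device=device)
    fns = {"cat": via_cat, "repeat": via_repeat, "indexing": via_indexing}
    reference = via_cat(mesh, indices)
    timings = {}
    for name, fn in fns.items():
        assert torch.equal(fn(mesh, indices), reference)  # same values, same column order
        timer = benchmark.Timer(stmt="fn(mesh, indices)",
                                globals={"fn": fn, "mesh": mesh, "indices": indices})
        timings[name] = timer.blocked_autorange(min_run_time=min_run_time).median  # seconds
    return timings


if __name__ == "__main__":
    for name, seconds in compare().items():
        print(f"{name:>9}: {seconds * 1e3:.3f} ms")
```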
