Select columns from a batch of matrices by index

Let’s consider

a = torch.arange(8).view(2,2,2)

which should look like

[[[0, 1],
  [2, 3]],

 [[4, 5],
  [6, 7]]]

and also

index = torch.tensor([1, 0])

“index” holds, for each matrix in the batch, the column index that should be extracted from “a”,
so the output should look like this:

[[1, 4],
 [3, 6]]

Although it doesn’t seem difficult, torch.gather and other indexing methods didn’t get me there. Any suggestions?

You could directly index the tensor and then transpose the output:

res = a[torch.arange(2), :, index]
> tensor([[1, 3],
          [4, 6]])
res = res.t()
> tensor([[1, 4],
          [3, 6]])

Oh, nice. I tried multiplying “a” by one-hot vectors:

index_tensor = index.view(*index.size(), -1)                 # (2, 1)
one_hot = torch.zeros(*index.size(), 2, dtype=index.dtype)   # (2, 2)
one_hot = one_hot.scatter(1, index_tensor, 1).unsqueeze(-1)  # (2, 2, 1)
result = a.matmul(one_hot).squeeze(-1).transpose(-2, -1)     # (2, 2)
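As a quick self-contained check (using the example tensors from the thread), the one-hot construction gives the same result as the direct indexing:

```python
import torch

a = torch.arange(8).view(2, 2, 2)
index = torch.tensor([1, 0])

# one-hot route: build a (2, 2, 1) selector and batch-multiply
index_tensor = index.view(*index.size(), -1)
one_hot = torch.zeros(*index.size(), 2, dtype=index.dtype)
one_hot = one_hot.scatter(1, index_tensor, 1).unsqueeze(-1)
result = a.matmul(one_hot).squeeze(-1).transpose(-2, -1)

# direct-indexing route from the accepted answer
expected = a[torch.arange(2), :, index].t()
assert torch.equal(result, expected)
print(result)
# tensor([[1, 4],
#         [3, 6]])
```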

Overcomplicating things always causes trouble!
Thank you.


I am a bit confused about the autograd implications of this operation. I am essentially trying to do this on the outputs of an LSTM, so I was wondering whether it breaks the computational graph. Thank you.

Assuming you are concerned about the indexing and transpose operations: no, they won’t break the computation graph, as seen here:

a = torch.arange(8).view(2, 2, 2).float().requires_grad_()
index = torch.tensor([0, 1])
res = a[torch.arange(2), :, index]
res = res.t()
res.mean().backward()
print(a.grad)
> tensor([[[0.2500, 0.0000],
           [0.2500, 0.0000]],

          [[0.0000, 0.2500],
           [0.0000, 0.2500]]])

The gradient will be calculated for the selected elements properly.
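To make the LSTM case concrete, here is a minimal sketch (the sizes are made up for illustration) showing that the selection keeps a grad_fn and that gradients reach the LSTM parameters:

```python
import torch

torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=3, hidden_size=4, batch_first=True)
out, _ = lstm(torch.randn(2, 5, 3))        # out: (batch=2, seq=5, hidden=4)
index = torch.tensor([1, 0])               # one hidden unit per batch element

sel = out[torch.arange(2), :, index]       # (2, 5), still part of the graph
print(sel.grad_fn is not None)             # True
sel.sum().backward()                       # gradients flow into the LSTM
print(lstm.weight_ih_l0.grad is not None)  # True
```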


Thanks for the clarification. I was concerned about the indexing part, but now it’s clear.