Select certain indices from a tensor

I have two tensors A and B. A is a float tensor with shape (batch size, hidden dim). B is a long tensor with shape (batch size, data len). What I want is something like A[:, B]: a float tensor with shape (batch size, data len) whose elements are taken from the corresponding row of A at the indices given by B.

An example would be A=[[5, 2, 6], [7, 3, 4]] and B=[[0, 2, 1, 1], [2, 2, 1, 0]]. Then what I want is a tensor [[5, 6, 2, 2], [4, 4, 3, 7]]. Is there any way to achieve this?

I have tried A[:, B], but what I get is a tensor with shape (batch size, batch size, data len), which is large, and I only want the “diagonal” values from it.
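A small sketch using the tensors from the example above shows where the extra batch dimension comes from and that the wanted values sit on its diagonal. The row-number indexing at the end is plain advanced indexing, offered here as an alternative to building the large tensor:

```python
import torch

A = torch.tensor([[5., 2., 6.], [7., 3., 4.]])  # (batch=2, hidden_dim=3)
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]])  # (batch=2, data_len=4), int64

# Advanced indexing pairs every row i of A with every row j of B:
# full[i, j, k] = A[i, B[j, k]]
full = A[:, B]
print(full.shape)  # torch.Size([2, 2, 4])

# The wanted values are the "diagonal" i == j (the big tensor is still built first).
rows = torch.arange(A.size(0))
diag = full[rows, rows]
print(diag.tolist())  # [[5.0, 6.0, 2.0, 2.0], [4.0, 4.0, 3.0, 7.0]]

# Indexing dim 0 with a column of row numbers broadcasts against B and
# yields the (batch, data_len) result directly: direct[i, k] = A[i, B[i, k]]
direct = A[rows.unsqueeze(1), B]
print(torch.equal(direct, diag))  # True
```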

Hi,

I think you want: A.gather(1, B).
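As a quick check against the example in the question (values copied from there): torch.gather along dim 1 picks out[i][k] = A[i][B[i][k]], and the output takes the shape of the index tensor, so B may have more columns than A as long as both have the same number of dimensions.

```python
import torch

A = torch.tensor([[5., 2., 6.], [7., 3., 4.]])  # (batch, hidden_dim)
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]])  # (batch, data_len), int64 indices into dim 1 of A

# gather along dim 1: out[i][k] = A[i][B[i][k]]; out has the same shape as B
out = A.gather(1, B)
print(out.tolist())  # [[5.0, 6.0, 2.0, 2.0], [4.0, 4.0, 3.0, 7.0]]
```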


Thanks for the reply! What about when A has another dimension? Now the shape of A is (batch size, hidden dim, data dim), and B is still a long tensor with shape (batch size, data len). I want the resulting tensor to have shape (batch size, data len, data dim).

For example, A=[[[5], [2], [6]], [[7], [3], [4]]] and B=[[0, 2, 1, 1], [2, 2, 1, 0]]. Then what I want is the tensor [[[5], [6], [2], [2]], [[4], [4], [3], [7]]]. Here data dim is 1, just as an example.

It seems that A.gather(1, B) requires A and B to have the same number of dimensions.

If A has extra dimensions, you can simply expand B along them:

A.gather(1, B.unsqueeze(-1).expand(batch, data_len, data_dim))
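Applied to the three-dimensional example from the question, the unsqueeze-and-expand approach can be sketched as follows; batch, data_len and data_dim are read from the tensor shapes here rather than assumed:

```python
import torch

# A: (batch=2, hidden_dim=3, data_dim=1), B: (batch=2, data_len=4)
A = torch.tensor([[[5.], [2.], [6.]], [[7.], [3.], [4.]]])
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]])

batch, data_len = B.shape
data_dim = A.size(2)

# gather needs an index with as many dims as A, so add a trailing dim to B
# and expand it (a view, no copy) to (batch, data_len, data_dim).
index = B.unsqueeze(-1).expand(batch, data_len, data_dim)
out = A.gather(1, index)  # out[i][k][d] = A[i][B[i][k]][d]
print(out.shape)  # torch.Size([2, 4, 1])
print(out.squeeze(-1).tolist())  # [[5.0, 6.0, 2.0, 2.0], [4.0, 4.0, 3.0, 7.0]]
```

Because expand only returns a view of B, building the index does not copy B's data; gather then returns a new tensor of shape (batch size, data len, data dim).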