Pick certain indices from a tensor

I have two tensors A and B. A is a float tensor with shape (batch size, hidden dim). B is a Long tensor with shape (batch size, data len). What I want is something like A[:, B]: a float tensor of shape (batch size, data len) whose elements are taken from the corresponding row of A at the indices given by B.

An example would be A=[[5, 2, 6], [7, 3, 4]] and B=[[0, 2, 1, 1], [2, 2, 1, 0]]. Then what I want is a tensor [[5, 6, 2, 2], [4, 4, 3, 7]]. Is there any way to achieve this?

I have tried A[:, B], but what I get is a tensor with shape (batch size, batch size, data len), which can be very large. I only want the "diagonal" values of this tensor.
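To see why A[:, B] is too large, here is a minimal sketch on the example tensors (full and wanted are illustrative names). It materializes all batch × batch × data len values and then keeps only the entries where the two batch indices match, which is exactly the waste the question wants to avoid.

```python
import torch

A = torch.tensor([[5, 2, 6], [7, 3, 4]]).float()
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]])

# A[:, B] indexes dim 1 of every row of A with all of B:
# full[k, i, j] = A[k, B[i, j]], shape (batch, batch, data_len)
full = A[:, B]
print(full.shape)  # torch.Size([2, 2, 4])

# The wanted values are the "diagonal" over the two batch dims (k == i).
idx = torch.arange(A.size(0))
wanted = full[idx, idx]  # shape (batch, data_len)
print(wanted)
```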

import torch

A = torch.tensor([[5, 2, 6], [7, 3, 4]]).float()
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]]).long()
# gather along dim 1: result[i][j] = A[i][B[i][j]]
result = torch.gather(A, 1, B)
print(result)  # values: [[5., 6., 2., 2.], [4., 4., 3., 7.]]
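The same selection can also be written with advanced indexing, which may make it clearer what gather does. This is an alternative sketch, not what the reply above used: a column of row indices of shape (batch, 1) broadcasts against B's (batch, data len) shape, so each B[i][j] is paired with its own row i.

```python
import torch

A = torch.tensor([[5, 2, 6], [7, 3, 4]]).float()
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]])

# gather along dim 1: out[i][j] = A[i][B[i][j]]
out_gather = torch.gather(A, 1, B)

# Advanced indexing: rows has shape (batch, 1) and broadcasts against B,
# so out_index[i][j] = A[rows[i][0]][B[i][j]] = A[i][B[i][j]]
rows = torch.arange(A.size(0)).unsqueeze(1)
out_index = A[rows, B]

print(torch.equal(out_gather, out_index))  # True
```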

Thanks for the reply! What about when A has an extra dimension, so that its shape is (batch size, hidden dim, data dim)? B is unchanged: a long tensor with shape (batch size, data len). I want the result to have shape (batch size, data len, data dim).

For example, A=[[[5], [2], [6]], [[7], [3], [4]]] and B=[[0, 2, 1, 1], [2, 2, 1, 0]]. Then what I want is a tensor [[[5], [6], [2], [2]], [[4], [4], [3], [7]]]. Here data dim is 1, just to keep the example small.
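The desired operation can be pinned down with a plain loop: out[i, j, :] = A[i, B[i, j], :]. The helper pick_rows_loop below is hypothetical and only meant as a slow reference to check faster versions against.

```python
import torch

def pick_rows_loop(A, B):
    """Slow reference: out[i, j, :] = A[i, B[i, j], :].

    A: float tensor of shape (batch, hidden_dim, data_dim)
    B: long tensor of shape (batch, data_len)
    """
    batch, data_len = B.shape
    out = A.new_empty((batch, data_len, A.size(2)))
    for i in range(batch):
        for j in range(data_len):
            # copy the whole data_dim vector at row B[i, j] of sample i
            out[i, j] = A[i, int(B[i, j])]
    return out

A = torch.tensor([[[5], [2], [6]], [[7], [3], [4]]]).float()
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]])
result = pick_rows_loop(A, B)
print(result.shape)  # torch.Size([2, 4, 1])
```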

It seems that torch.gather(A, 1, B) requires B to have the same number of dimensions as A, but here B is 2-D while A is 3-D.

In this case, add a trailing dimension to B and expand it to shape (batch size, data len, data dim), i.e. the shape of A except along dim 1, then gather along dim 1 as before.
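A minimal sketch of that suggestion, assuming every entry along data dim should come from the same row B[i, j]; pick_rows is a hypothetical helper name. B only needs to be expanded to (batch size, data len, data dim), not to A's full shape: gather's output takes the index's shape and only requires index.size(d) <= A.size(d) for d != 1.

```python
import torch

def pick_rows(A, B):
    """out[i, j, :] = A[i, B[i, j], :] without Python loops.

    A: (batch, hidden_dim, data_dim) float tensor
    B: (batch, data_len) long tensor
    returns: (batch, data_len, data_dim)
    """
    data_dim = A.size(2)
    # (batch, data_len) -> (batch, data_len, 1) -> (batch, data_len, data_dim);
    # expand returns a view, so the index values are not copied.
    index = B.unsqueeze(-1).expand(-1, -1, data_dim)
    # gather along dim 1: out[i, j, k] = A[i, index[i, j, k], k]
    return torch.gather(A, 1, index)

A = torch.tensor([[[5], [2], [6]], [[7], [3], [4]]]).float()
B = torch.tensor([[0, 2, 1, 1], [2, 2, 1, 0]])
out = pick_rows(A, B)
print(out.shape)  # torch.Size([2, 4, 1])

# Larger data_dim, cross-checked against advanced indexing
A2 = torch.randn(2, 3, 5)
rows = torch.arange(A2.size(0)).unsqueeze(1)
print(torch.equal(pick_rows(A2, B), A2[rows, B]))  # True
```

Since the expanded index is a view, the extra memory cost is small compared with building the (batch size, batch size, data len, data dim) tensor that A[:, B] would produce.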