tf.gather_nd in PyTorch

For the code below, I want to use tensor “b” (the index pairs [0, 1] and [1, 2]) to pick rows out of tensor “a”, the way tf.gather_nd does in TensorFlow. The output should be a 2x3 tensor with the values [3, 4, 5] and [18, 19, 20].

import torch

a = torch.FloatTensor([[[ 0,  1,  2],
                        [ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]],

                       [[12, 13, 14],
                        [15, 16, 17],
                        [18, 19, 20],
                        [21, 22, 23]]])
print(a.shape)  # torch.Size([2, 4, 3])
b = torch.FloatTensor([[0, 1], [1, 2]])
print(b.shape)  # torch.Size([2, 2])

Direct indexing should work:

a = torch.tensor([[[ 0,  1,  2],
                   [ 3,  4,  5],
                   [ 6,  7,  8],
                   [ 9, 10, 11]],

                  [[12, 13, 14],
                   [15, 16, 17],
                   [18, 19, 20],
                   [21, 22, 23]]]).float()

b = torch.tensor([[0, 1], [1, 2]])

print(a[b[0, :], b[1, :]])
> tensor([[ 3.,  4.,  5.],
          [18., 19., 20.]])
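
Since tf.gather_nd treats each row of the index tensor as one coordinate tuple, the same result can also be written with the columns of b, which generalizes to an index tensor with more than two rows (a minimal sketch, using the same a and b as above):

print(a[b[:, 0], b[:, 1]])
> tensor([[ 3.,  4.,  5.],
          [18., 19., 20.]])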

(1) I just posted sample code for the query. The actual tensors are on CUDA, so when I run the code below on the CPU the session just disconnects from Google Colab, and when I run it on the GPU it throws the error “CUDA error: device-side assert triggered”.

(2) Just assume that the first dimension of tensor “a” is the batch dimension (2 in this case, since a has shape torch.Size([2, 4, 3]), so the batch numbers are 0 and 1), and that each row of tensor “b” holds a batch number and the row location that needs to be extracted (for batch 0: [0, 1] and [0, 2]; for batch 1: [1, 2]).

I just want to generalize this so that only the arguments a and b are passed, without manual indexing, similar to tf.gather_nd in TensorFlow. I hope that makes the issue clear.


a = torch.tensor([[[ 0,  1,  2],
                   [ 3,  4,  5],
                   [ 6,  7,  8],
                   [ 9, 10, 11]],

                  [[12, 13, 14],
                   [15, 16, 17],
                   [18, 19, 20],
                   [21, 22, 23]]]).float()

b = torch.tensor([[0, 1], [0, 2], [1, 2]])

a = a.to(0)  # move both tensors to the first CUDA device
b = b.to(0)
print(a[b[0, :], b[1, :]])
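
For reference, here is a minimal sketch of a gather_nd-style helper under the assumption from (2) that each row of b is one (batch, row) index pair; gather_nd is just a local helper name here, not a built-in PyTorch function:

import torch

def gather_nd(params, indices):
    # treat each row of `indices` as one coordinate tuple into the
    # leading dimensions of `params`; trailing dimensions are kept as-is
    return params[tuple(indices.t())]

a = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)  # same values as above
b = torch.tensor([[0, 1], [0, 2], [1, 2]])

if torch.cuda.is_available():  # the indexing works the same on CPU and GPU
    a, b = a.to(0), b.to(0)

print(gather_nd(a, b))  # rows a[0,1], a[0,2], a[1,2] -> [3,4,5], [6,7,8], [18,19,20]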