For the code below, I want to extract the rows of tensor “a” that are selected by the index pairs in tensor “b” ([0, 1] and [1, 2]), the way the function “tf.gather_nd” does in TensorFlow. The output should be a 2x3 tensor with the values [3, 4, 5] and [18, 19, 20].
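import torch

# Source tensor of shape (2, 4, 3)
a = torch.FloatTensor([[[ 0,  1,  2],
                        [ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]],
                       [[12, 13, 14],
                        [15, 16, 17],
                        [18, 19, 20],
                        [21, 22, 23]]])
print(a.shape)  # torch.Size([2, 4, 3])

# Index pairs (position in dim 0, position in dim 1), shape (2, 2)
b = torch.FloatTensor([[0, 1], [1, 2]])
print(b.shape)  # torch.Size([2, 2])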
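One possible way to get this result, as a minimal sketch: PyTorch's advanced indexing accepts one integer index tensor per leading dimension, so the columns of "b" can be split apart and passed as a tuple. The helper name gather_nd below is hypothetical (it is not a PyTorch function), and the sketch assumes the indices can be cast to integers, since a float tensor cannot be used as an index. It does not cover the batch_dims option of tf.gather_nd.

```python
import torch

# Same values as the tensor "a" in the question, shape (2, 4, 3)
a = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
# Index pairs; kept as float here to mirror the question
b = torch.FloatTensor([[0, 1], [1, 2]])


def gather_nd(params, indices):
    """Hypothetical helper that mimics the basic case of tf.gather_nd.

    indices has shape (..., K) with K <= params.dim(); each length-K
    entry addresses the first K dimensions of params.
    """
    indices = indices.long()  # index tensors must have an integer dtype
    # Split the last dimension into K index tensors, one per indexed dim
    return params[tuple(indices.unbind(dim=-1))]


out = gather_nd(a, b)
print(out.shape)
print(out)
```

For this specific two-column case, the same selection can also be written directly as a[b[:, 0].long(), b[:, 1].long()].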