Batch index_select

I have a batch of data in tensor A:

[[[66.,  0.,  0.],
  [77.,  0.,  0.],
  [54.,  0.,  0.],
  [33.,  0.,  0.]],

 [[74.,  0.,  0.],
  [31.,  0.,  0.],
  [43.,  0.,  0.],
  [53.,  0.,  0.]]]

I have an index tensor B:

[[ 0,  2],
 [ 1,  3]]

How do I extract rows 0 and 2 from the first sample, and rows 1 and 3 from the second sample, to get:

[[[66.,  0.,  0.],
  [54.,  0.,  0.]],

 [[31.,  0.,  0.],
  [53.,  0.,  0.]]]

I found a function called [index_select](https://pytorch.org/docs/stable/torch.html#torch.index_select), but it doesn't work for batches because its index argument must be a 1-D tensor. How can I do this for a batch?
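For reference, a simple (if slower) workaround is to loop over the batch, call `torch.index_select` once per sample with that sample's 1-D index row, and stack the results. This is a minimal sketch, assuming `A` has shape `(batch, rows, features)` and `B` holds `int64` row indices of shape `(batch, k)`; the helper name `batch_index_select_loop` is hypothetical:

```python
import torch

def batch_index_select_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: for every sample i, select rows B[i] from A[i].
    # index_select only takes a 1-D index, so it is applied once per sample
    # and the per-sample results are stacked back into a batch.
    return torch.stack([torch.index_select(a, 0, b) for a, b in zip(A, B)])

A = torch.tensor([[[66., 0., 0.],
                   [77., 0., 0.],
                   [54., 0., 0.],
                   [33., 0., 0.]],
                  [[74., 0., 0.],
                   [31., 0., 0.],
                   [43., 0., 0.],
                   [53., 0., 0.]]])
B = torch.tensor([[0, 2],
                  [1, 3]])

out = batch_index_select_loop(A, B)
print(out[:, :, 0])  # first column of each selected row: [[66., 54.], [31., 53.]]
```

This runs a Python-level loop over the batch, so it scales poorly with large batch sizes; a single `gather` call avoids the loop.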

I just found a way to do this using gather:

import torch
# Expand B from (batch, k) to (batch, k, features) so gather copies whole rows
dummy = B.unsqueeze(2).expand(B.size(0), B.size(1), A.size(2))
out = torch.gather(A, 1, dummy)
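To make the gather trick concrete, here is a self-contained version using the tensors from the question. `torch.gather(input, dim, index)` expects `index` to have the same number of dimensions as `input`, which is why `B` is unsqueezed and expanded across the feature dimension before the call. The `expected` tensor is copied from the question:

```python
import torch

A = torch.tensor([[[66., 0., 0.],
                   [77., 0., 0.],
                   [54., 0., 0.],
                   [33., 0., 0.]],
                  [[74., 0., 0.],
                   [31., 0., 0.],
                   [43., 0., 0.],
                   [53., 0., 0.]]])
B = torch.tensor([[0, 2],
                  [1, 3]])

# (batch, k) -> (batch, k, 1) -> (batch, k, features):
# every feature column of a selected row reuses that row's index.
dummy = B.unsqueeze(2).expand(B.size(0), B.size(1), A.size(2))

# Gather along dim 1 (the row dimension) of each sample.
out = torch.gather(A, 1, dummy)

expected = torch.tensor([[[66., 0., 0.],
                          [54., 0., 0.]],
                         [[31., 0., 0.],
                          [53., 0., 0.]]])
print(torch.equal(out, expected))  # True
```

`expand` returns a view without allocating new memory, so building the index tensor stays cheap even for a wide feature dimension.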

Yes, gather is the way to go here!
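Why this works: for a 3-D input gathered along `dim=1`, the PyTorch documentation defines each output element as below. Writing `idx` for the expanded copy of `B` (every feature position $f$ of row $j$ holds $B[i][j]$), each output row is exactly the row of `A` that `B` selects:

$$
\text{out}[i][j][f] = A\big[i\big]\big[\text{idx}[i][j][f]\big]\big[f\big],
\qquad \text{idx}[i][j][f] = B[i][j]
\;\Longrightarrow\;
\text{out}[i][j][f] = A\big[i\big]\big[B[i][j]\big]\big[f\big].
$$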