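import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])
# How can I do this in a batch, instead of row by row?
c_1 = a[0][idx[0]].view(1, -1)  # select from row 0 using idx[0]
c_2 = a[1][idx[1]].view(1, -1)  # select from row 1 using idx[1]
c = torch.cat((c_1, c_2), dim=0)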
The desired output is:
tensor([[1, 3, 2],
        [7, 8, 5]])
I tried a[idx], but it does not work.
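For context, a minimal check of why a[idx] fails here: with a single index tensor, PyTorch indexes along dimension 0 (the rows), so every value in idx is treated as a row number. idx contains 2 and 3, but a has only two rows. The sketch assumes only that torch is installed; the names failed and rows are illustrative.

```python
import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])

# A single index tensor selects along dim 0, i.e. it picks whole rows.
try:
    a[idx]
    failed = False
except (IndexError, RuntimeError) as err:  # row index 2 is out of range for 2 rows
    failed = True
    print(type(err).__name__, err)

# With in-range row numbers it "works", but returns whole rows, not elements.
rows = a[torch.tensor([[0, 1], [1, 0]])]
print(rows.shape)  # torch.Size([2, 2, 4])
```

The result shape is idx's shape followed by the remaining dimensions of a, which is why a[idx] can never produce the per-row selection on its own.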
LeviViana (Levi Viana), May 31, 2019, 12:03pm
What are idx and the desired output in the snippet above?
Hi, I made a typo and have just corrected it. idx holds the indices of the elements to select from each row of the tensor.
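Put differently, row i of the result is a[i][idx[i]]. A minimal, non-batched reference sketch that generalizes the two-row snippet to any number of rows with torch.stack is handy for checking a batched version against. The helper name select_per_row_loop is hypothetical, and it assumes a and idx are 2-D with idx.size(0) <= a.size(0).

```python
import torch

def select_per_row_loop(a: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Reference (non-batched) version: row i of the output is a[i][idx[i]].

    Hypothetical helper; assumes 2-D a and idx, and idx.size(0) <= a.size(0).
    """
    return torch.stack([a[i][idx[i]] for i in range(idx.size(0))], dim=0)

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])
c = select_per_row_loop(a, idx)
print(c)
# tensor([[1, 3, 2],
#         [7, 8, 5]])
```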
LeviViana (Levi Viana), May 31, 2019, 12:27pm
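import torch
a = torch.tensor([[1,2,3,4],[5,6,7,8]])
idx = torch.tensor([[0,2,1],[2,3,0]])
# Offset row i's indices by i * a.size(1) so they address the flattened tensor
idx2 = idx + torch.arange(idx.size(0)).view(-1, 1) * a.size(1)
# Index the flattened 1-D view; the result takes idx2's shape
c = a.view(-1)[idx2]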
It works as long as a and idx have only two dimensions and idx.size(0) <= a.size(0).
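The offset trick can be wrapped in a small function. This sketch (the name flat_select is hypothetical) uses reshape(-1) instead of view(-1): view requires a memory layout compatible with flattening and raises on some non-contiguous tensors such as a transpose, whereas reshape copies when it has to. It keeps the same assumptions as the post: 2-D tensors and idx.size(0) <= a.size(0).

```python
import torch

def flat_select(a: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Per-row selection via flat indexing (hypothetical helper name).

    Assumes a and idx are 2-D, idx holds int64 column indices,
    and idx.size(0) <= a.size(0).
    """
    assert a.dim() == 2 and idx.dim() == 2 and idx.size(0) <= a.size(0)
    # In row-major order, row i of the flattened tensor starts at i * a.size(1).
    offsets = torch.arange(idx.size(0), device=idx.device).view(-1, 1) * a.size(1)
    # reshape (unlike view) also works for non-contiguous inputs.
    return a.reshape(-1)[idx + offsets]

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])
print(flat_select(a, idx))

# A transposed (non-contiguous) input: a.t().view(-1) would raise here.
at = a.t()  # shape (4, 2)
print(flat_select(at, torch.tensor([[1, 0], [0, 0]])))
```

One design caveat: the flat index is only bounds-checked against the whole tensor, so a column index of a.size(1) or more in row i silently reads from row i + 1 instead of raising an error.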
Thank you for your help, @LeviViana. That certainly solves it; however, I wonder whether PyTorch can support a[idx] directly. Is that possible, @albanD?
albanD (Alban D), June 2, 2019, 2:28pm
Hi, gather is what you want!
c = a.gather(1, idx)
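A minimal sketch of the gather answer, assuming the question's a and idx. For a 2-D input, a.gather(1, idx) computes out[i][j] = a[i][idx[i][j]]; the index tensor must be int64 and have the same number of dimensions as a, and the output takes idx's shape. The functional form torch.gather(a, 1, idx) is equivalent.

```python
import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])  # int64 by default

# out[i][j] = a[i][idx[i][j]]: gather along dim 1 (columns), row by row.
c = a.gather(1, idx)
print(c)
# tensor([[1, 3, 2],
#         [7, 8, 5]])

# Same result via the functional form and via the per-row loop from the question.
assert torch.equal(c, torch.gather(a, 1, idx))
assert torch.equal(c, torch.stack([a[i][idx[i]] for i in range(a.size(0))]))
```

Unlike the flat-index trick, gather checks each index against the size of dimension 1, and it extends to tensors with more dimensions as long as idx has the same number of dimensions as a. Recent PyTorch releases (1.9 and later) also provide torch.take_along_dim(a, idx, dim=1), modeled on NumPy's take_along_axis.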
jia_lee (Jia Li), December 31, 2020, 7:08am
a_tensor[idx] is supported; I use it often.
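A caveat worth illustrating: with this particular a and idx, a plain a[idx] still indexes along dimension 0, so it does not perform the per-row selection. Indexing with two tensors, a column of row numbers broadcast against idx, does. A minimal sketch, assuming the question's a and idx; rows is an illustrative name.

```python
import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])

# Row numbers with shape (2, 1); broadcasting against idx (2, 3) pairs
# each column index idx[i][j] with its row i.
rows = torch.arange(idx.size(0)).unsqueeze(1)
c = a[rows, idx]
print(c)
# tensor([[1, 3, 2],
#         [7, 8, 5]])
```

This follows the same advanced-indexing rules as NumPy: the two index tensors are broadcast to a common shape, and element (i, j) of the result is a[rows[i, 0], idx[i, j]].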