How to swap column positions in a PyTorch tensor?

Example (each number is a column index):

[0,1,2] => [0,2,1]


Hi,

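You can use torch.gather for that (see its documentation).
In your case it should look like torch.gather(input_2d, 1, column_indices), where column_indices holds the new column order repeated once per row, because gather expects an index tensor with the same number of dimensions as the input.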
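A minimal sketch of the gather approach, assuming a 3x3 float tensor like the one used later in this thread. With dim=1, gather computes out[i][j] = input[i][index[i][j]], so the 1-D column order is broadcast to one copy per row before the call:

```python
import torch

t = torch.tensor([[1., 2., 3.], [3., 4., 5.], [7., 9., 9.]])
perm = torch.tensor([0, 2, 1])  # desired column order (int64)

# gather needs an index with the same number of dims as the input,
# so repeat the column order for every row: shape (3,) -> (3, 3)
index = perm.unsqueeze(0).expand(t.size(0), -1)

# out[i][j] = t[i][index[i][j]]
out = torch.gather(t, 1, index)
print(out)
```

The expanded index is a view with no extra memory; gather itself returns a new tensor.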


Thanks, I see. I feel torch.gather is not very convenient here, so I used torch.index_select instead:

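```python
import torch

t = torch.tensor([[1., 2., 3.], [3., 4., 5.], [7., 9., 9.]])
print(t)
# columns [0, 2] followed by column [1], concatenated along dim 1
torch.cat((torch.index_select(t, 1, torch.tensor([0, 2])), torch.index_select(t, 1, torch.tensor([1]))), dim=1)
```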

```
 1  2  3
 3  4  5
 7  9  9
[torch.FloatTensor of size 3x3]

Out[66]:
 1  3  2
 3  5  4
 7  9  9
[torch.FloatTensor of size 3x3]
```

In that case, a single index_select with the full column order is even simpler:

```python
torch.index_select(t, 1, torch.tensor([0, 2, 1]))
```
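To swap any two columns, one option is to build the permutation from torch.arange and pass it to index_select. This is a sketch only; swap_columns is a hypothetical helper name, not a PyTorch function. Plain advanced indexing (t[:, [0, 2, 1]]) gives the same result:

```python
import torch

def swap_columns(t: torch.Tensor, i: int, j: int) -> torch.Tensor:
    """Hypothetical helper: return a copy of 2-D tensor t with columns i and j exchanged."""
    perm = torch.arange(t.size(1))  # identity order: [0, 1, ..., n-1]
    perm[i], perm[j] = j, i         # exchange the two positions
    return torch.index_select(t, 1, perm)

t = torch.tensor([[1., 2., 3.], [3., 4., 5.], [7., 9., 9.]])
a = swap_columns(t, 1, 2)
b = t[:, [0, 2, 1]]  # advanced indexing with the same column order
print(a)
print(torch.equal(a, b))  # True
```

Both index_select and advanced indexing return a new tensor rather than a view, so t itself is left unchanged.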
