Example: the number is the column.
Hi,
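You can use torch.gather
for that, see the doc here.
In your case it should look like torch.gather(input_2d, 1, column_indices),
where column_indices is a LongTensor with one row of column positions for each row of input_2d.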
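A minimal sketch of the gather approach, assuming the goal is to move column 2 in front of column 1 (the order [0, 2, 1] is chosen only for illustration). torch.gather needs an index tensor with the same number of dimensions as the input, so the column order is repeated once per row:

```python
import torch

# The 3x3 matrix used later in this thread.
t = torch.tensor([[1., 2., 3.],
                  [3., 4., 5.],
                  [7., 9., 9.]])

# Hypothetical target order: column 0, then column 2, then column 1.
order = torch.tensor([0, 2, 1])

# gather wants a 2-D index for a 2-D input: repeat the order for every row,
# giving shape (3, 3).
column_indices = order.unsqueeze(0).repeat(t.size(0), 1)

# Along dim=1: result[i][j] = t[i][column_indices[i][j]]
result = torch.gather(t, 1, column_indices)
print(result)
```

Because each row has its own row of indices, gather can also apply a different column order to every row; when all rows share one order, the extra index tensor is unnecessary overhead.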
Thanks, I see it. I feel torch.gather isn't that useful here, so I use torch.index_select instead:
import torch
t = torch.Tensor([[1,2,3],[3,4,5], [7,9,9]])
print(t)
# take columns 0 and 2, then column 1, and concatenate them along dim 1
torch.cat((torch.index_select(t, 1, torch.LongTensor([0,2])),
           torch.index_select(t, 1, torch.LongTensor([1]))), dim=1)
1 2 3
3 4 5
7 9 9
[torch.FloatTensor of size 3x3]
Out[66]:
1 3 2
3 5 4
7 9 9
[torch.FloatTensor of size 3x3]
In that case, a single index_select with the new column order is even simpler:
torch.index_select(t, 1, torch.LongTensor([0,2,1]))
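A quick check, assuming a recent PyTorch where torch.tensor is available, that the single index_select call gives the same result as the two-step cat version above. The last variant, t[:, order], is NumPy-style advanced indexing; it is an alternative not mentioned in the thread, included only for comparison:

```python
import torch

t = torch.tensor([[1., 2., 3.],
                  [3., 4., 5.],
                  [7., 9., 9.]])
order = torch.tensor([0, 2, 1])  # new column order

# One call: select whole columns of t (dim=1) in the given order.
single = torch.index_select(t, 1, order)

# Two-step version: columns [0, 2], then column [1], concatenated along dim 1.
two_step = torch.cat((torch.index_select(t, 1, torch.tensor([0, 2])),
                      torch.index_select(t, 1, torch.tensor([1]))), dim=1)

# Advanced indexing with a column index tensor.
indexed = t[:, order]

assert torch.equal(single, two_step) and torch.equal(single, indexed)
print(single)
```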