In that case, this should be even simpler:
torch.index_select(t, 1, torch.tensor([0, 2, 1]))  # along dim 1 (columns): take column 0, then 2, then 1
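Here is a small self-contained sketch of the call above. It assumes `t` is a 2-D tensor whose columns should come back in the order 0, 2, 1. The sample tensor built with `torch.arange` is a hypothetical stand-in for the `t` in the question. It also shows that plain advanced indexing, `t[:, order]`, gives the same result:

```python
import torch

# Hypothetical 2-D tensor standing in for `t` from the question.
t = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# New column order: keep column 0, then column 2, then column 1.
order = torch.tensor([0, 2, 1])  # int64 index tensor, same dtype as the old LongTensor

# index_select gathers the listed indices along dim 1 (the columns).
reordered = torch.index_select(t, 1, order)

# Advanced indexing on the second dimension produces the same tensor.
same = t[:, order]

print(reordered)  # tensor([[0, 2, 1], [3, 5, 4]])
```

Both forms return a new tensor holding copied data, not a view of `t`.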