Let's say I have an input tensor input = torch.tensor([1,2,3,4,5]), and I want an output tensor output that looks like [[2,3],[4,2],[3,1]], or has any other arbitrary shape.
So far I see that this works:
import torch
t = torch.Tensor([[1,2],[3,4]])
c = torch.gather(t, 0, torch.LongTensor([[0,0],[1,1]]))
# c is tensor([[1., 2.], [3., 4.]])
But I don't understand why this is the output, or why the index has to have the same number of dimensions and the same size as the input.
I have worked with tf.gather() and I know how it works. Can somebody explain the logic of torch.gather()?
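
For reference, here is a minimal sketch of how I understand the documented rule for 2-D inputs: with dim=0 the index value replaces the row coordinate, and with dim=1 it replaces the column coordinate. gather_by_loops is a hypothetical helper written only for illustration, not part of PyTorch. The last lines show that the tf.gather-style lookup from the first example can be written as plain indexing with an integer tensor.

```python
import torch

def gather_by_loops(inp, dim, index):
    """Hypothetical loop-based reference for torch.gather on 2-D tensors."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            k = int(index[i, j])
            if dim == 0:
                out[i, j] = inp[k, j]  # index value replaces the row coordinate
            else:
                out[i, j] = inp[i, k]  # index value replaces the column coordinate
    return out

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 1]])

# Compare the loop version with the built-in along both dimensions
same_dim0 = torch.equal(gather_by_loops(t, 0, idx), torch.gather(t, 0, idx))
same_dim1 = torch.equal(gather_by_loops(t, 1, idx), torch.gather(t, 1, idx))

# tf.gather-style lookup on a 1-D tensor: index with an integer tensor;
# the result takes the shape of the index tensor
inp = torch.tensor([1, 2, 3, 4, 5])
lookup = inp[torch.tensor([[1, 2], [3, 1], [2, 0]])]
```

So the output of torch.gather always has the shape of the index tensor, and each index entry only chooses the position along dim, while the other coordinate comes from where that entry sits in the index.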