Let's say I have an input tensor input = torch.tensor([1,2,3,4,5]) and I want an output tensor that looks like [[2,3],[4,2],[3,1]], or has any other arbitrary shape.

So far I see that this works:

t = torch.Tensor([[1,2],[3,4]])

c = torch.gather(t, 0, torch.LongTensor([[0,0],[1,1]]))

c = [[1., 2.], [3., 4.]]

But I don't understand why this is the output, or why the index has to have the same number of dimensions and the same size as the input.
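
For reference, the PyTorch documentation states the rule for a 2-D input with dim=0 as out[i][j] = input[index[i][j]][j]. Below is a minimal sketch that checks this rule against an explicit loop. The names gather_dim0, idx_a and idx_b are hypothetical, added only for illustration; they are not part of torch or of the snippet above.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])

def gather_dim0(src, index):
    # Explicit-loop version of torch.gather(src, 0, index) for 2-D tensors:
    # out[i][j] = src[index[i][j]][j]
    out = torch.empty_like(index, dtype=src.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[i, j] = src[index[i, j], j]
    return out

idx_a = torch.tensor([[0, 0], [1, 1]])  # index used in the snippet above
idx_b = torch.tensor([[0, 1], [1, 1]])  # assumed variant, for comparison

res_a = torch.gather(t, 0, idx_a)
res_b = torch.gather(t, 0, idx_b)

print(res_a.tolist())  # [[1, 2], [3, 4]]
print(res_b.tolist())  # [[1, 4], [3, 4]]
```

Under this rule the index [[0,0],[1,1]] returns t unchanged, while [[0,1],[1,1]] yields [[1,4],[3,4]]. The index only replaces the coordinate along dim; every other coordinate is taken from the position in the index tensor itself.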

I have worked with tf.gather() and know how it works. Can somebody explain the logic of torch.gather()?
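
For concreteness, the tf.gather-style lookup from the first example can be sketched with plain tensor indexing. The positions tensor is an assumption, chosen so that it reproduces [[2,3],[4,2],[3,1]]; inp stands in for the input tensor to avoid shadowing Python's built-in input.

```python
import torch

inp = torch.tensor([1, 2, 3, 4, 5])

# Assumed positions into inp; the result takes the shape of this tensor.
positions = torch.tensor([[1, 2], [3, 1], [2, 0]])

out = inp[positions]               # advanced (integer-array) indexing
same = torch.take(inp, positions)  # treats inp as 1-D; same result here

print(out.tolist())  # [[2, 3], [4, 2], [3, 1]]
```

Here the output shape follows the index tensor rather than the input, which is the behaviour I am used to from tf.gather().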