I’d like to use torch.gather with the following code -
source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
index = torch.tensor([[0], [2], [0]])
# index = torch.tensor([0,2,0])
source.gather(dim=1, index=index)
to get the following output -
tensor([[1],
        [6],
        [7]])
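For reference, with dim=1 gather computes out[i][j] = source[i][index[i][j]], and the output takes the shape of index. That is why the (3, 1) index above gives a (3, 1) result. This is also why the extra dimension is needed: gather expects index to have the same number of dimensions as source. The sketch below models only that rule with plain nested lists. `gather_dim1` is a hypothetical helper written for illustration. It is not part of PyTorch and not how PyTorch implements gather.

```python
def gather_dim1(source, index):
    """Hypothetical pure-Python model of source.gather(dim=1, index=index) for 2-D lists.

    out[i][j] = source[i][index[i][j]]; the result has the same shape as index.
    """
    return [[source[i][k] for k in row] for i, row in enumerate(index)]


source = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
index = [[0], [2], [0]]  # one column index per row, shape (3, 1)
print(gather_dim1(source, index))  # [[1], [6], [7]]
```

Only the first three rows of source are read, because index has three rows.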
However, my index array has this form -
index = torch.tensor([0,2,0])
Is there a simple way to convert torch.tensor([0,2,0]) (shape (3,)) to torch.tensor([[0], [2], [0]]) (shape (3, 1))?
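One way to do this is to add a trailing dimension of size 1. The sketch below uses `Tensor.unsqueeze(1)`. `index.view(-1, 1)` and `index[:, None]` give the same (3, 1) tensor. It assumes the same `source` and 1-D `index` as in the question. The name `index_2d` is illustrative only.

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([0, 2, 0])  # shape (3,)

# Insert a new dimension at position 1: shape (3,) -> (3, 1)
index_2d = index.unsqueeze(1)
# Equivalent alternatives: index.view(-1, 1) or index[:, None]

out = source.gather(dim=1, index=index_2d)
print(out)  # a (3, 1) tensor holding 1, 6, 7
```

If a flat result is preferred, `out.squeeze(1)` drops the extra dimension again.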