I have two tensors, A and its corresponding index tensor, for example:
A = torch.tensor([2,3,41,5,12,45,4576,2,1])
index = torch.tensor([0,0,1,1,2,2,2,2,3])
What I want is a list of lists, for example:
Res = [
[2,3],
[41,5],
[12,45,4576,2],
[1]
]
so that Res[i] is a list containing all the elements of A whose corresponding index is i.
I can do this with a loop, but I think there may be a better way to do it in torch (a Python loop seems slow without vectorization).
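For reference, the loop version I have in mind looks roughly like this. It is only a minimal sketch, and `group_by_index_loop` is a name I made up for illustration:

```python
import torch

A = torch.tensor([2, 3, 41, 5, 12, 45, 4576, 2, 1])
index = torch.tensor([0, 0, 1, 1, 2, 2, 2, 2, 3])

def group_by_index_loop(values, index):
    # Hypothetical helper: builds one boolean mask per group,
    # so the cost grows with (number of groups) * (number of elements).
    num_groups = int(index.max()) + 1
    return [values[index == i].tolist() for i in range(num_groups)]

res_loop = group_by_index_loop(A, index)
print(res_loop)
```

Every iteration scans the whole tensor, which is the part I would like to avoid.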
Is there a better way to do this? I can get the maximum list length, so fixed-length solutions are also fine.
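To make the fixed-length variant concrete, this is the kind of padded output I would accept. It is a rough sketch that assumes the index tensor is sorted and non-negative (as in my example); the padding value -1 and the variable names are my own choices, not anything from a library:

```python
import torch

A = torch.tensor([2, 3, 41, 5, 12, 45, 4576, 2, 1])
index = torch.tensor([0, 0, 1, 1, 2, 2, 2, 2, 3])

# Assumption: index is sorted, so each group is a contiguous run in A.
counts = torch.bincount(index)                  # size of each group
starts = torch.cumsum(counts, dim=0) - counts   # offset of each group's first element
pos = torch.arange(A.numel()) - starts[index]   # position of each element inside its group

# Assumption: -1 never occurs in A, so it can serve as padding.
padded = torch.full((counts.numel(), int(counts.max())), -1, dtype=A.dtype)
padded[index, pos] = A
print(padded)
```

Each row of `padded` holds one group, left-aligned, with the unused slots filled by the padding value.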
(Briefly speaking, what I want is torch_scatter.segment_coo without reduce)
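Put differently, I am after only the grouping step. With plain PyTorch I can get close using `torch.bincount` and `torch.split`, sketched below, but this again assumes a sorted index and I am not sure it is the best approach:

```python
import torch

A = torch.tensor([2, 3, 41, 5, 12, 45, 4576, 2, 1])
index = torch.tensor([0, 0, 1, 1, 2, 2, 2, 2, 3])

# Assumption: index is sorted, so splitting A by group sizes yields the groups.
counts = torch.bincount(index)
groups = torch.split(A, counts.tolist())   # tuple of 1-D tensors, one per index value
res_split = [g.tolist() for g in groups]
print(res_split)
```

The `.tolist()` call on `counts` still moves the sizes to the host, so this may not be ideal on a GPU.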
Thanks in advance for any help or suggestions.