Sort elements and create subsets of a tensor

I have two tensors:

T = torch.tensor([20, 4, 5, 9, 22, 5, 50, 9, 45, 5])
I = torch.tensor([ 0, 2, 3, 1,  1, 3,  1, 3,  0, 0])

Both tensors have the same number of elements, and tensor ‘I’ holds values from 0 to 3. I need to build 4 tensors by taking elements from tensor ‘T’ according to the corresponding labels in tensor ‘I’,
so that the output would be:

a0 = [20, 45, 5]
a1 = [9, 22, 50]
a2 = [4]
a3 = [5, 5, 9]

My actual tensors are very large, so what is the fastest and most efficient way to do this on the GPU?
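One common GPU-friendly pattern, not taken from this thread, is to sort the labels once and then split: a stable sort of I brings equal labels together while keeping their original order, torch.bincount gives the size of each group, and torch.split cuts the reordered T into one view per label. A minimal sketch, assuming the number of groups (4 here) is known in advance:

```python
import torch

# Example data from the question; for the GPU, move both tensors with .to("cuda").
T = torch.tensor([20, 4, 5, 9, 22, 5, 50, 9, 45, 5])
I = torch.tensor([0, 2, 3, 1, 1, 3, 1, 3, 0, 0])
num_groups = 4  # assumption: labels in I range over 0..3

# A stable sort keeps elements that share a label in their original order.
_, order = torch.sort(I, stable=True)
# Number of elements carrying each label (minlength keeps empty groups).
counts = torch.bincount(I, minlength=num_groups)
# Reorder T once, then split it into consecutive views, one per label.
groups = torch.split(T[order], counts.tolist())

for k, g in enumerate(groups):
    print(f"a{k} =", g.tolist())
```

This prints a0 = [20, 45, 5], a1 = [9, 22, 50], a2 = [4] and a3 = [5, 5, 9]. The call counts.tolist() copies the group sizes to the host, which cannot be avoided because they depend on the data; the rest is a sort, a count and a gather, a fixed number of tensor operations regardless of how many groups there are, unlike a loop with one mask per label.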

You could try this code and see if it’s faster than your current approach:

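import torch

t = torch.tensor([[20, 4, 5, 9, 22, 5, 50, 9, 45, 5]]).float()
i = torch.tensor([[0, 2, 3, 1, 1, 3, 1, 3, 0, 0]]).long()

# Row k of z receives the values of t whose label in i is k; every other entry stays 0.
z = torch.zeros(4, t.size(1), device=t.device)
z.scatter_(0, i, t)

tensors = []
for z_ in z:
    # Keep the filled entries of each row as a 1-D tensor.
    # Caveat: this also drops elements of t that are genuinely 0.
    tensors.append(z_[z_ != 0])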
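To check whether the scatter approach is actually faster, a timing harness along these lines compares it with a plain per-label boolean mask, T[I == k], after first confirming that both give the same groups. It is a sketch under stated assumptions: the problem size and value range are hypothetical, the values are drawn to be at least 1 because the scatter version treats 0 as empty, and torch.utils.benchmark.Timer is used because it synchronizes asynchronous CUDA work before reading the clock.

```python
import torch
from torch.utils import benchmark

def groups_scatter(t, i, num_groups):
    # Scatter approach from the answer, on 1-D inputs: row k of z holds t where i == k.
    z = torch.zeros(num_groups, t.numel(), dtype=t.dtype, device=t.device)
    z.scatter_(0, i.unsqueeze(0), t.unsqueeze(0))
    # Dropping zeros recovers each group, assuming t itself contains no zeros.
    return [row[row != 0] for row in z]

def groups_mask(t, i, num_groups):
    # Baseline: one boolean-mask pass over i per label.
    return [t[i == k] for k in range(num_groups)]

device = "cuda" if torch.cuda.is_available() else "cpu"
n, num_groups = 100_000, 4  # hypothetical size; increase to match your data
t = torch.randint(1, 100, (n,), device=device).float()  # values >= 1, so no zeros
i = torch.randint(0, num_groups, (n,), device=device)

# Both methods must produce identical groups before timing them means anything.
ref = groups_mask(t, i, num_groups)
out = groups_scatter(t, i, num_groups)
assert all(torch.equal(a, b) for a, b in zip(ref, out))

for fn in (groups_scatter, groups_mask):
    timer = benchmark.Timer(
        stmt="fn(t, i, num_groups)",
        globals={"fn": fn, "t": t, "i": i, "num_groups": num_groups},
    )
    print(fn.__name__, timer.timeit(20))
```

The scatter version allocates a num_groups × N buffer, so its memory cost grows with the number of groups; which method wins depends on N, the number of groups and the hardware, so it is worth timing with realistic sizes.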