How to split a Tensor based on class id

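Yes, that’s possible with a direct comparison: comparing the tensor against a class id gives a boolean mask, which you can then use to index the tensor: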

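```python
import torch

x = torch.tensor([0, 1, 1, 0])
idx0 = x == 0  # boolean mask: tensor([True, False, False, True])
idx1 = x == 1  # boolean mask: tensor([False, True, True, False])

x[idx0]
# tensor([0, 0])
x[idx1]
# tensor([1, 1])
```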
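The same masking idea extends to splitting a separate data tensor (for example, per-sample features) by a 1-D tensor of class ids, one mask per class. The sketch below assumes the data and the labels share their first dimension; `split_by_class`, `data`, and `labels` are hypothetical names for illustration, not part of PyTorch:

```python
import torch
from typing import Dict


def split_by_class(data: torch.Tensor, labels: torch.Tensor) -> Dict[int, torch.Tensor]:
    """Map each class id found in `labels` to the rows of `data` carrying that id.

    Hypothetical helper; assumes `labels` is 1-D and data.shape[0] == labels.shape[0].
    """
    if labels.dim() != 1 or data.shape[0] != labels.shape[0]:
        raise ValueError("labels must be 1-D and match data along dim 0")
    splits = {}
    # torch.unique returns the distinct ids, sorted ascending by default
    for class_id in torch.unique(labels):
        mask = labels == class_id  # boolean mask, same idea as idx0/idx1 above
        splits[int(class_id)] = data[mask]  # selects the matching rows (a copy)
    return splits


labels = torch.tensor([0, 1, 1, 0])
data = torch.arange(8).reshape(4, 2)  # hypothetical features: 4 samples, 2 values each
parts = split_by_class(data, labels)
# parts[0] holds rows 0 and 3: [[0, 1], [6, 7]]
# parts[1] holds rows 1 and 2: [[2, 3], [4, 5]]
```

Note that indexing with a boolean mask returns a new tensor rather than a view, so modifying `parts[0]` does not change `data`.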