How to split a Tensor based on class id

Hi All,

Let’s say I have a tensor of shape [batch_size, num_input] whose entries are one of 2 class ids (0 or 1). How can I get the indices of the entries that belong to each class?

import torch

batch_size = 100
num_input = 4

data = torch.randint(0, 2, size=(batch_size, num_input))  # fake data: 0 and 1 are the two classes

Let’s say one example (a single row) was [0,1,1,0]. I’d want a function that returns [0,3] for class_id = 0 and [1,2] for class_id = 1. The idea is that I could compute idx0 and idx1 and then simply do data0 = data[idx0] and data1 = data[idx1]. Is this possible?
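For a single row, a minimal sketch of such a function could turn the comparison mask into integer positions with `torch.nonzero(..., as_tuple=True)`. This assumes a 1-D tensor of class ids, and `class_indices` is a hypothetical helper name, not a PyTorch API:

```python
import torch


def class_indices(x: torch.Tensor, class_id: int) -> torch.Tensor:
    """Hypothetical helper: positions in a 1-D tensor x where x == class_id."""
    # nonzero(as_tuple=True) returns one index tensor per dimension;
    # for a 1-D input there is exactly one, so take element [0].
    return torch.nonzero(x == class_id, as_tuple=True)[0]


x = torch.tensor([0, 1, 1, 0])
idx0 = class_indices(x, 0)  # tensor([0, 3])
idx1 = class_indices(x, 1)  # tensor([1, 2])

# Integer index tensors can be used directly for indexing.
data0 = x[idx0]  # tensor([0, 0])
data1 = x[idx1]  # tensor([1, 1])
```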

I’ve looked at (data==0).nonzero(as_tuple=True), but there are a few issues with it. For example, if a sample contains only one class, the result for the other class has no entries for that sample, so it is missed entirely. I’ve also tried getting this to work for batches of data, but it doesn’t seem to broadcast. I did try torch.func.vmap, but vmap doesn’t support dynamic shapes!
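For the batched case, one possible sketch (assuming a 2-D tensor of class ids; `per_row_class_indices` is a hypothetical helper name) calls `nonzero` once on the whole mask and then splits the column indices by the per-row match counts with `torch.split`. A row without the class gets an empty tensor instead of disappearing. The result is a Python list because rows can have different numbers of matches, so it is not a single fixed-shape tensor:

```python
from typing import List

import torch


def per_row_class_indices(data: torch.Tensor, class_id: int) -> List[torch.Tensor]:
    """Hypothetical helper: for each row of a 2-D tensor, the column indices
    where data == class_id. Rows with no match get an empty tensor."""
    mask = data == class_id
    # nonzero walks the tensor in row-major order, so the column indices
    # come out grouped by row, in row order.
    cols = mask.nonzero(as_tuple=True)[1]
    # Number of matches per row (0 for rows that lack this class).
    counts = mask.sum(dim=1)
    # Split the flat column list into one chunk per row; zero-size chunks are allowed.
    return list(torch.split(cols, counts.tolist()))


example = torch.tensor([[0, 1, 1, 0],
                        [1, 1, 1, 1]])
rows0 = per_row_class_indices(example, 0)  # [tensor([0, 3]), tensor([])]
rows1 = per_row_class_indices(example, 1)  # [tensor([1, 2]), tensor([0, 1, 2, 3])]

# Same call on the random batch from the question.
data = torch.randint(0, 2, size=(100, 4))
batch_rows0 = per_row_class_indices(data, 0)  # one index tensor per row
```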

Another thought I had was to use torch.where(data==1, data, 0), but that leaves me with numerous rows filled with zeros. Is it possible to delete rows that are only filled with zeros?
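Dropping all-zero rows can be sketched with a per-row `any` reduction used as a boolean row mask. This assumes, as in the question, that 0 is the fill value, and the small `data` tensor here is made up for illustration. Note it only removes whole rows; zeros inside the kept rows stay in place:

```python
import torch

data = torch.tensor([[0, 1, 1, 0],
                     [0, 0, 0, 0],
                     [1, 0, 0, 1]])

# Keep class-1 entries, set everything else to 0 (as in the question).
masked = torch.where(data == 1, data, 0)

# True for every row that has at least one nonzero entry.
keep = (masked != 0).any(dim=1)

# Boolean indexing along dim 0 drops the all-zero rows.
nonzero_rows = masked[keep]  # tensor([[0, 1, 1, 0], [1, 0, 0, 1]])

# Original row numbers of the rows that were kept.
kept_rows = keep.nonzero(as_tuple=True)[0]  # tensor([0, 2])
```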

Any help would be greatly appreciated!

Yes, that’s possible with a direct comparison, which gives a boolean mask you can index with:

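import torch

x = torch.tensor([0, 1, 1, 0])
idx0 = x == 0  # boolean mask: tensor([True, False, False, True])
idx1 = x == 1  # boolean mask: tensor([False, True, True, False])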

x[idx0]
# tensor([0, 0])
x[idx1]
# tensor([1, 1])
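As a sketch of how the same comparison behaves on a 2-D batch like the one in the question (the small `data` tensor here is made up for illustration): boolean indexing still works, but it returns a flattened 1-D tensor, so the row structure is lost. `nonzero()` recovers the (row, column) coordinates, and a per-row sum shows which rows have no entries of that class:

```python
import torch

data = torch.tensor([[0, 1, 1, 0],
                     [1, 1, 1, 1]])

mask0 = data == 0

# Boolean indexing on a 2-D tensor returns the selected values as a flat 1-D tensor.
flat0 = data[mask0]  # tensor([0, 0])

# One (row, col) pair per True entry in the mask.
coords0 = mask0.nonzero()  # tensor([[0, 0], [0, 3]])

# Matches per row; a 0 marks a row that contains no class-0 entries.
counts0 = mask0.sum(dim=1)  # tensor([2, 0])
```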