Hi All,

Let’s say I have a tensor of shape `[batch_size, num_input]` where every entry belongs to one of 2 classes. How could I get the indices of each class?

```python
import torch

batch_size = 100
num_input = 4
data = torch.randint(0, 2, size=(batch_size, num_input))  # fake data: 0 and 1 are the two classes
```

Let’s say I had one example that was `[0,1,1,0]`. I’d want a function that returns `[0,3]` for class_id = 0 or `[1,2]` for class_id = 1, so that I could get `idx0` and `idx1` and then simply do `data0 = data[idx0]` and `data1 = data[idx1]`. Is this possible?
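For a single 1-D example, a minimal sketch of the per-class indexing described above (assuming PyTorch and the `[0,1,1,0]` example) uses `nonzero(as_tuple=True)` and takes the first element of the returned tuple, since a 1-D tensor has only one index dimension:

```python
import torch

# A single example with two classes (0 and 1), as in the question
row = torch.tensor([0, 1, 1, 0])

# nonzero(as_tuple=True) returns one index tensor per dimension;
# a 1-D tensor has a single dimension, so take element [0]
idx0 = (row == 0).nonzero(as_tuple=True)[0]  # tensor([0, 3])
idx1 = (row == 1).nonzero(as_tuple=True)[0]  # tensor([1, 2])

# Index the example with the per-class indices
data0 = row[idx0]  # tensor([0, 0])
data1 = row[idx1]  # tensor([1, 1])
```

The single-argument form `torch.where(row == 0)` returns the same tuple as `torch.nonzero(row == 0, as_tuple=True)`, so either spelling works here.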

I’ve had a look at `(data==0).nonzero(as_tuple=True)`, but there are a few issues with it. For example, if a sample contains only one class, that sample is missed entirely. I’ve also tried getting this to work for batches of data, but it doesn’t seem to broadcast. I did try `torch.func.vmap`, but `vmap` doesn’t support dynamic shapes!
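For a batch, the per-sample index lists have different lengths, so the natural result is a list of tensors rather than a single tensor; that data-dependent output shape is presumably what `vmap` can’t handle. A minimal sketch, assuming a 2-D `data` tensor (the helper name `class_indices_per_row` is hypothetical, and the small `data` tensor is made-up illustration data): take the column indices from `nonzero`, which come back in row-major order, and split them by the per-row match counts, so a row with no matches yields an empty tensor instead of disappearing.

```python
from typing import List

import torch


def class_indices_per_row(data: torch.Tensor, class_id: int) -> List[torch.Tensor]:
    """Hypothetical helper: one 1-D tensor of column indices per row of `data`.

    Rows with no entries equal to `class_id` get an empty tensor, so every
    sample keeps its position in the returned list.
    """
    mask = data == class_id                # [batch_size, num_input], bool
    _, cols = mask.nonzero(as_tuple=True)  # sorted row-major, i.e. grouped by row
    counts = mask.sum(dim=1)               # number of matches in each row
    return list(torch.split(cols, counts.tolist()))


data = torch.tensor([[0, 1, 1, 0],
                     [1, 1, 1, 1],   # contains no 0s at all
                     [0, 0, 1, 0]])

idx0 = class_indices_per_row(data, 0)  # as lists: [[0, 3], [], [0, 1, 3]]
idx1 = class_indices_per_row(data, 1)  # as lists: [[1, 2], [0, 1, 2, 3], [2]]

# Gather each sample's values; the result is a list because lengths differ
data0 = [row[i] for row, i in zip(data, idx0)]
```

If a fixed-shape result is acceptable, the boolean mask `data == class_id` has the same shape as `data` and can be used directly (for example with `torch.where`), which sidesteps the ragged output altogether.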

Another thought I had was to use `torch.where(data==1, data, 0)`, but that leaves me with numerous rows filled with zeros. Is it possible to delete rows that are filled only with zeros?
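Dropping the all-zero rows can be done with a boolean row mask. A sketch, assuming a 2-D integer tensor and the class-1 masking from the question (the small `data` tensor is made-up illustration data):

```python
import torch

data = torch.tensor([[0, 1, 1, 0],
                     [0, 0, 0, 0],
                     [1, 0, 0, 0]])

# Keep class-1 entries and zero out everything else, as in the question
masked = torch.where(data == 1, data, 0)

# A row survives if it has at least one non-zero entry
keep = (masked != 0).any(dim=1)             # tensor([True, False, True])
filtered = masked[keep]                     # the all-zero row is removed
kept_rows = keep.nonzero(as_tuple=True)[0]  # original batch indices: tensor([0, 2])
```

After filtering, `filtered` no longer lines up with the original batch, which is why the sketch also keeps `kept_rows`. For class 0, zero-filling can’t distinguish kept entries from removed ones, so `(data == 0).any(dim=1)` is the safer row mask there.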

Any help would be greatly appreciated!