Let’s say I have a tensor of shape `[batch_size, num_input]` where I have 2 classes for my input data. How could I get the indices for each of the classes?
```python
import torch

batch_size = 100
num_input = 4
# fake data: 0 and 1 are the two classes
data = torch.randint(0, 2, size=(batch_size, num_input))
```
Let’s say one example was `[0,1,1,0]`. I’d want a function that returns `[0,3]` for `class_id = 0` or `[1,2]` for `class_id = 1`, so that I could get `idx0` and `idx1` and then simply do `data0 = data[idx0]` and `data1 = data[idx1]`. Is this possible?
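For a single example, the function described above can be sketched as follows. `class_indices` is a hypothetical helper name, not a PyTorch API; it only wraps `nonzero` on a 1-D boolean mask.

```python
import torch

def class_indices(row: torch.Tensor, class_id: int) -> torch.Tensor:
    # Hypothetical helper: positions in a 1-D tensor whose value equals class_id
    return (row == class_id).nonzero(as_tuple=True)[0]

row = torch.tensor([0, 1, 1, 0])
idx0 = class_indices(row, 0)  # tensor([0, 3])
idx1 = class_indices(row, 1)  # tensor([1, 2])

# gather the entries belonging to each class
data0 = row[idx0]
data1 = row[idx1]
```

Indexing with the boolean mask directly (`row[row == 0]`) gives the same values without materializing the indices. This covers one row at a time; across a batch, each row can have a different number of matches, so the indices don't fit in one rectangular tensor.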
I’ve had a look at `(data==0).nonzero(as_tuple=True)`, but there are a few issues with it. For example, if a sample contains only one class, that sample is missing entirely from the indices for the other class. I’ve also tried getting this to work for batches of data, but it doesn’t seem to broadcast. I did try using `vmap`, but `vmap` doesn’t support dynamic shapes!
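One possible way to keep samples that have no matches, assuming a tuple of per-row index tensors is an acceptable result: on a 2-D mask, `nonzero(as_tuple=True)` returns flat row and column coordinates in row-major order, so the column indices can be split back into per-row chunks using the number of matches in each row. `indices_per_row` is a hypothetical helper; rows that contain only the other class come back as empty tensors instead of being dropped.

```python
from typing import Tuple

import torch

def indices_per_row(data: torch.Tensor, class_id: int) -> Tuple[torch.Tensor, ...]:
    # Hypothetical helper: for each row of a 2-D tensor, the column indices equal to class_id
    mask = data == class_id
    # nonzero lists coordinates in row-major order: all matches in row 0,
    # then all matches in row 1, and so on
    _, cols = mask.nonzero(as_tuple=True)
    # number of matches per row (0 for rows that only contain the other class)
    counts = mask.sum(dim=1)
    # split the flat column indices back into one tensor per row;
    # a count of 0 yields an empty tensor, so no row disappears
    return cols.split(counts.tolist())

data = torch.tensor([[0, 1, 1, 0],
                     [1, 1, 1, 1],
                     [0, 0, 1, 0]])
per_row = indices_per_row(data, 0)
# per-row values, e.g. the class-0 entries of each sample
values = [row[idx] for row, idx in zip(data, per_row)]
```

The result is a Python tuple rather than a single tensor, which sidesteps the dynamic-shape problem at the cost of a ragged structure; values for row `i` are `data[i][per_row[i]]`.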
Another thought I had was to use `torch.where(data==1, data, 0)`, but that leaves me with numerous rows filled with zeros. Is it possible to delete rows that are only filled with zeros?
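A sketch of dropping rows that are entirely zero, assuming a 2-D tensor: build a boolean row mask with `any(dim=1)` and index along the first dimension with it. `drop_zero_rows` is a hypothetical name, and `torch.zeros_like(data)` stands in for the scalar `0` so the `torch.where` call also works on older PyTorch versions that require a tensor for `other`.

```python
import torch

def drop_zero_rows(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: True for rows with at least one non-zero entry
    keep = (x != 0).any(dim=1)
    # boolean indexing along dim 0 keeps only those rows
    return x[keep]

data = torch.tensor([[0, 1, 1, 0],
                     [0, 0, 0, 0],
                     [1, 0, 0, 0]])
# zero out everything that is not class 1
masked = torch.where(data == 1, data, torch.zeros_like(data))
filtered = drop_zero_rows(masked)

# alternative: filter on the class mask itself, which also works for class 0
class_id = 0
with_class = data[(data == class_id).any(dim=1)]
```

For class 0 the zero-row filter is not useful, because the kept entries are themselves zeros; filtering on `data == class_id` instead keeps every sample that contains the class, for either class.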
Any help would be greatly appreciated!