Hi All,
Let’s say I have a Tensor of shape [batch_size, num_input] where I have 2 classes for my input data. How could I get the indices for each of the classes?
```python
import torch

batch_size = 100
num_input = 4
data = torch.randint(0, 2, size=(batch_size, num_input))  # fake data: 0 and 1 are the two classes
```
Let’s say I had one example that was [0,1,1,0]. I’d want a function that returns [0,3] for class_id = 0 or [1,2] for class_id = 1. The idea is that I could get idx0 and idx1 and then simply do data0 = data[idx0] and data1 = data[idx1]. Is this possible?
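For a single 1-D example, one way to get these indices (a sketch, not the only option) is to compare against the class id and call nonzero(as_tuple=True), which returns one index tensor per dimension, so for a 1-D input only the first element is needed. The helper name class_indices is hypothetical.

```python
import torch

def class_indices(example: torch.Tensor, class_id: int) -> torch.Tensor:
    # Hypothetical helper: boolean mask of matching positions, then their indices.
    # nonzero(as_tuple=True) returns a tuple with one tensor per dimension;
    # a 1-D example has a single dimension, hence the [0].
    return (example == class_id).nonzero(as_tuple=True)[0]

example = torch.tensor([0, 1, 1, 0])
idx0 = class_indices(example, 0)
idx1 = class_indices(example, 1)
print(idx0.tolist())  # [0, 3]
print(idx1.tolist())  # [1, 2]
```

torch.where(example == class_id) with a single argument returns the same tuple as nonzero(as_tuple=True), so either spelling should work here.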
I’ve had a look at (data==0).nonzero(as_tuple=True), but there are a few issues with it. For example, if a data sample contains only one class, that sample is missed entirely. I’ve also tried getting this to work for batches of data, but it doesn’t seem to broadcast. I did try using torch.func.vmap, but vmap doesn’t support dynamic shapes!
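A possible batched workaround, sketched under the assumption that a Python list of per-row index tensors is acceptable: call nonzero once on the whole 2-D mask, then split the column indices by the number of matches in each row using torch.split. Because nonzero returns indices in row-major order, the columns for each row are contiguous, and rows without the class get an empty tensor instead of disappearing. The function name per_row_class_indices is hypothetical. The result is ragged (rows have different numbers of matches), so it cannot be a single regular tensor; if a fixed shape is required, for example under vmap, the boolean mask data == class_id itself may be the more practical representation.

```python
import torch
from typing import List

def per_row_class_indices(data: torch.Tensor, class_id: int) -> List[torch.Tensor]:
    # Hypothetical helper: one (possibly empty) index tensor per row of `data`.
    mask = data == class_id                    # [batch_size, num_input] boolean mask
    _, cols = mask.nonzero(as_tuple=True)      # row-major order, so each row's cols are contiguous
    counts = mask.sum(dim=1).tolist()          # matches per row; 0 for rows lacking the class
    return list(torch.split(cols, counts))     # zero-length chunks keep "missing" rows

data = torch.tensor([[0, 1, 1, 0],
                     [1, 1, 1, 1]])
idx0 = per_row_class_indices(data, 0)
idx1 = per_row_class_indices(data, 1)
print([t.tolist() for t in idx0])  # [[0, 3], []]
print([t.tolist() for t in idx1])  # [[1, 2], [0, 1, 2, 3]]
```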
Another thought I had was to use torch.where(data==1, data, 0), but that leaves me with numerous rows filled with zeros. Is it possible to delete rows that are only filled with zeros?
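Removing all-zero rows can be done with a per-row boolean mask and boolean indexing. The sketch below uses a small hand-written tensor in place of the fake data; torch.zeros_like is used as the fill value only to avoid depending on scalar support in torch.where. Note that this keys on non-zero values, so it works for keeping rows that contain class 1 but not for class 0, whose entries are themselves zero.

```python
import torch

data = torch.tensor([[0, 1, 1, 0],
                     [0, 0, 0, 0],
                     [1, 0, 0, 1]])
masked = torch.where(data == 1, data, torch.zeros_like(data))

keep = (masked != 0).any(dim=1)   # True for rows with at least one non-zero entry
nonzero_rows = masked[keep]       # boolean row indexing drops the all-zero rows
print(keep.tolist())              # [True, False, True]
print(nonzero_rows.tolist())      # [[0, 1, 1, 0], [1, 0, 0, 1]]
```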
Any help would be greatly appreciated!