What is the best way to select a subset of data from a batch during the forward pass, based on a condition?
For example, the following seems to take twice the amount
# some condition; squeeze(1) keeps ids 1-D even when only one sample matches
ids = torch.nonzero(torch.ones(batch_size)).squeeze(1).cuda()
subset_x = x[ids]
subset_x = conv(subset_x)
compared to processing the whole batch of data:
x = conv(x)
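
For reference, here is a minimal, self-contained sketch of the selection step written two ways: with a boolean mask and with explicit indices. The layer, the tensor sizes, and the "mean is positive" condition are assumptions chosen only for illustration; they are not from my actual model. It falls back to the CPU when CUDA is not available.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical sizes and layer, chosen only for illustration.
batch_size = 8
x = torch.randn(batch_size, 3, 16, 16, device=device)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1).to(device)

# Example condition (an assumption): keep samples whose mean is positive.
mask = x.mean(dim=(1, 2, 3)) > 0  # bool tensor of shape (batch_size,)

# Option 1: boolean-mask indexing, no explicit index tensor needed.
subset_mask = x[mask]

# Option 2: explicit indices, built on the same device as x.
ids = torch.nonzero(mask, as_tuple=True)[0]  # stays 1-D even for one match
subset_ids = x.index_select(0, ids)

# Both options select the same rows; run the layer on the subset only.
out = conv(subset_mask)
```

One caveat when comparing timings like the ones above: CUDA kernels launch asynchronously, so a measurement taken without calling torch.cuda.synchronize() before reading the clock may not reflect the real cost of either variant.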