Does indexing and slicing a batch lead to slowness?

What is the best way to select a subset of the data in a batch during the forward pass, based on some condition?

For example, the following seems to take twice the amount of time

ids = torch.squeeze(torch.nonzero(torch.ones(batch_size))).cuda()  # dummy condition that selects every sample
subset_x = x[ids]  # advanced indexing: gathers the selected samples
subset_x = conv(subset_x)

compared to processing the whole batch of data

x = conv(x)
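
If it helps frame the question: as far as I understand, indexing with a tensor of indices is advanced indexing, which materializes a copy, while basic slicing returns a view. So even a condition that selects every sample pays for a gather kernel and a fresh allocation. A quick check (shapes are arbitrary):

import torch

x = torch.randn(8, 3, 32, 32)
ids = torch.arange(8)

view = x[0:8]      # basic slicing: a view, no data is copied
gathered = x[ids]  # advanced indexing: gathers rows into new memory

print(view.data_ptr() == x.data_ptr())      # True -- shares storage
print(gathered.data_ptr() == x.data_ptr())  # False -- a copy was made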

The way you are doing it seems fine. Can you send a small snippet illustrating the 2x slowdown?

Here is the modified MNIST example to illustrate this; the core of the comparison is sketched below.
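
Since the full script is long, here is a condensed, self-contained sketch of the comparison it runs (the layer shape, batch size, and iteration count are illustrative, not the actual MNIST settings). Note that CUDA kernels launch asynchronously, so torch.cuda.synchronize() is needed for meaningful timings:

import time
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 64
conv = nn.Conv2d(32, 32, kernel_size=3, padding=1).to(device)
x = torch.randn(batch_size, 32, 28, 28, device=device)

# Condition that happens to select every sample, so both paths do the same conv work.
ids = torch.squeeze(torch.nonzero(torch.ones(batch_size))).to(device)

def timed(fn, iters=200):
    if device == "cuda":
        torch.cuda.synchronize()  # wait for pending work before starting the clock
    start = time.time()
    with torch.no_grad():
        for _ in range(iters):
            fn()
    if device == "cuda":
        torch.cuda.synchronize()  # wait for all queued kernels to finish
    return time.time() - start

print("full batch:", timed(lambda: conv(x)))
print("indexed   :", timed(lambda: conv(x[ids])))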

My current guess is that the cost of indexing adds up to a considerable amount if there are several such operations.
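
If that guess is right, one workaround would be to gather once up front rather than before every layer. A rough sketch of the two patterns (module names and sizes are made up for illustration):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(64, 32, 28, 28, device=device)
ids = torch.arange(64, device=device)  # stand-in for a computed condition
convs = nn.ModuleList(
    [nn.Conv2d(32, 32, kernel_size=3, padding=1) for _ in range(4)]
).to(device)

# Pattern 1: one gather per layer -- each out[ids] copies the data again.
out = x
for conv in convs:
    out = conv(out[ids])

# Pattern 2: gather once, then run every layer on the subset.
out = x[ids]
for conv in convs:
    out = conv(out)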