How to keep only top k percent values

Suppose that x is a tensor with shape [batch_size, channel_number, height, width], where some values are positive and the others are 0. I only want to keep the top 10% of positive values in each channel and set the rest to 0. The problem is that the number of positive pixels differs from channel to channel, so how can I do this efficiently?

It shouldn’t matter if some values are positive and others are zero, since the result tensor will contain zeros anyway. So even if the top 10% of a channel includes zeros (because that channel doesn't have enough positive values), you wouldn't change the output, would you?
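Here is a small check of that argument, assuming a non-negative input as described in the question (the tensor values are made up for illustration). One 10x10 channel has only 3 positive pixels, but k is 10% of all 100 pixels, so topk also picks 7 zeros. Scattering those zeros back into a zero tensor changes nothing:

```python
import torch

# Assumed setup: 1 batch, 1 channel, 10x10 pixels, only 3 positive values
x = torch.zeros(1, 1, 10, 10)
x[0, 0, 0, :3] = torch.tensor([0.5, 2.0, 1.0])

flat = x.view(1, 1, -1)
k = int(0.1 * flat.size(2))  # 10% of 100 pixels -> 10
ret = torch.topk(flat, k=k, dim=2)

# topk returns the 3 positives first, then 7 zeros
res = torch.zeros_like(flat).scatter_(2, ret.indices, ret.values).view_as(x)
```

Since every positive value survives and the extra picks are zeros, res is identical to x. Note this relies on there being no negative values: with signed data (e.g. torch.randn), a channel with fewer than k positives would also keep some negatives.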

If that’s the case, this code should work:

import torch

# Setup
x = torch.randn(2, 3, 10, 10)
orig_shape = x.size()

# Reshape and calculate positions of top 10%
x = x.view(x.size(0), x.size(1), -1)
nb_pixels = x.size(2)
ret = torch.topk(x, k=int(0.1*nb_pixels), dim=2)

# Scatter to zero'd tensor
res = torch.zeros_like(x)
res.scatter_(2, ret.indices, ret.values)
res = res.view(*orig_shape)
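The snippet above uses the same k for every channel (10% of all pixels). If you instead want 10% of each channel's own positive values, k varies per channel, which a single topk call can't express. A minimal sketch, assuming a non-negative float input and rounding k up (so any channel with at least one positive value keeps at least one); the function name keep_top_percent_of_positives is hypothetical. It ranks values within each channel using a double argsort and then masks by the per-channel k:

```python
import torch

def keep_top_percent_of_positives(x: torch.Tensor, percent: int = 10) -> torch.Tensor:
    """Keep the top `percent`% of positive values per channel, zero the rest.

    x: [batch_size, channels, height, width], assumed non-negative.
    """
    flat = x.reshape(x.size(0), x.size(1), -1)
    # Number of positive pixels per channel, shape [batch_size, channels]
    num_pos = (flat > 0).sum(dim=2)
    # Per-channel k = ceil(num_pos * percent / 100), in integer arithmetic
    k = (num_pos * percent + 99) // 100
    # Rank of each element within its channel (0 = largest)
    order = flat.argsort(dim=2, descending=True)
    ranks = order.argsort(dim=2)
    # Keep elements ranked below k that are also positive
    mask = (ranks < k.unsqueeze(-1)) & (flat > 0)
    out = torch.where(mask, flat, torch.zeros_like(flat))
    return out.reshape(x.shape)

# Example: channel 0 has 20 positive pixels, channel 1 has 5
x = torch.zeros(1, 2, 5, 5)
x[0, 0].view(-1)[:20] = torch.arange(1, 21).float()
x[0, 1].view(-1)[:5] = torch.arange(1, 6).float()
out = keep_top_percent_of_positives(x, percent=10)
```

Here channel 0 keeps its 2 largest values (ceil of 10% of 20) and channel 1 keeps only its largest (ceil of 10% of 5). The argsort approach costs a full sort per channel, so it is slower than topk; it's only worth it if k really must differ between channels.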

Thanks for your answer. Yes, it seems that the number of positive pixels actually doesn’t matter! Your answer helps me a lot, thanks again!