How to keep only the top k percent of values

Suppose that x is a tensor with shape [batch_size, channel_number, height, width], where some values are positive and the others are 0. I want to keep only the top 10% of the positive values in each channel and set everything else to 0. The problem is that the number of positive pixels differs from channel to channel, so how can I do this efficiently?


It shouldn’t matter that some values are positive and others are zero, since the result tensor will contain zeros anyway. So even if the top 10% includes some zeros (because a channel doesn’t have enough positive values), writing those zeros into the output wouldn’t change it, would it?
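A quick check of this reasoning on a single hypothetical channel of 100 pixels with only 3 positive values: `torch.topk` with k = 10 returns the 3 positives plus 7 zeros, and scattering those into a zero tensor reproduces the input exactly.

```python
import torch

# One hypothetical channel with 100 pixels, only 3 of which are positive.
x = torch.zeros(100)
x[[5, 42, 77]] = torch.tensor([0.9, 0.4, 0.7])

k = int(0.1 * x.numel())  # 10, i.e. more than the 3 positive values
top = torch.topk(x, k=k)

# The top-k values are the 3 positives followed by 7 zeros.
print(top.values)

# Scattering into a zero tensor writes 0 into the padded slots again,
# so the output equals the input: only the positives survive.
res = torch.zeros_like(x).scatter_(0, top.indices, top.values)
print(torch.equal(res, x))  # True
```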

If that’s the case, this code should work:

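import torch

# Setup: random example input of shape [batch_size, channels, height, width]
x = torch.randn(2, 3, 10, 10)
orig_shape = x.size()

# Flatten each channel and find the positions of its top 10% values
x = x.view(x.size(0), x.size(1), -1)
nb_pixels = x.size(2)
ret = torch.topk(x, k=int(0.1 * nb_pixels), dim=2)
# ret.indices has shape [batch_size, channels, k]

# Scatter the top values into a zero tensor; everything else stays 0
res = torch.zeros_like(x)
res.scatter_(2, ret.indices, ret.values)
res = res.view(*orig_shape)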
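Since the title asks about the top k percent in general, the same topk/scatter steps can be wrapped in a small function that takes the percentage as a parameter. This is a sketch: `keep_top_percent` is a hypothetical name, k is computed from height * width as in the code above, and clamping k to at least 1 is an added assumption (with `int(0.1 * nb_pixels)`, a channel with fewer than 10 pixels would get k = 0 and be zeroed entirely).

```python
import torch

def keep_top_percent(x: torch.Tensor, percent: float) -> torch.Tensor:
    # Hypothetical helper: zero everything except the largest `percent`%
    # of values in each channel of a [batch, channels, height, width] tensor.
    b, c, h, w = x.shape
    flat = x.reshape(b, c, h * w)

    # Multiply before dividing to limit float rounding; clamp k to >= 1 (assumption).
    k = max(1, int(h * w * percent // 100))

    vals, idx = torch.topk(flat, k=k, dim=2)
    out = torch.zeros_like(flat).scatter_(2, idx, vals)
    return out.view(b, c, h, w)

# Usage: non-negative input like in the question.
x = torch.relu(torch.randn(2, 3, 10, 10))
y = keep_top_percent(x, 10)
print((y != 0).sum(dim=(2, 3)))  # at most 10 non-zeros per channel
```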

Thanks for your answer. Yes, it seems that the number of positive pixels actually doesn’t matter! Your answer helped me a lot, thanks again!
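The thread settles on taking 10% of all pixels in a channel. If you instead need 10% of each channel’s positive pixels, as the question literally asks, k differs per channel, and `torch.topk` (which takes a single k) can’t express that directly. A minimal sketch under these assumptions: `keep_top_percent_of_positives` is a hypothetical name, and k is rounded down, so a channel with fewer than 10 positive pixels keeps nothing. Instead of topk, it sorts each channel once, converts the sort order into per-pixel ranks, and keeps pixels whose rank is below their channel’s k.

```python
import torch

def keep_top_percent_of_positives(x: torch.Tensor, percent: int = 10) -> torch.Tensor:
    # Hypothetical helper: per channel, keep the largest `percent`% of the
    # positive values, with k = floor(percent * n_positive / 100) varying
    # from channel to channel. Everything else is set to 0.
    b, c, h, w = x.shape
    flat = x.reshape(b, c, h * w)

    # Number of positive pixels per channel and the per-channel k, shape [b, c].
    n_pos = (flat > 0).sum(dim=2)
    k = torch.div(n_pos * percent, 100, rounding_mode="floor")

    # Sort each channel in descending order, so positive values come first.
    _, sorted_idx = flat.sort(dim=2, descending=True)

    # Inverse permutation: rank[..., p] is the sorted position of pixel p.
    rank = sorted_idx.argsort(dim=2)

    # Keep a pixel if its rank is below its channel's k (broadcast over pixels).
    # Since k <= n_pos, every kept pixel is positive.
    keep = rank < k.unsqueeze(-1)
    return flat.masked_fill(~keep, 0.0).view(b, c, h, w)

# Usage: two channels with different numbers of positive pixels.
inp = torch.zeros(1, 2, 100)
inp[0, 0, :20] = torch.arange(1.0, 21.0)  # 20 positives -> k = 2
inp[0, 1, :5] = 1.0                        # 5 positives  -> k = 0
y = keep_top_percent_of_positives(inp.view(1, 2, 10, 10), 10)
print(y.view(1, 2, 100)[0, 0].nonzero().flatten())  # tensor([18, 19])
```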