I want to remove a few channels depending on some criterion, for example based on probabilities. Say my probability vector is `[0.5, 0.3, 0.1, 0.1]`, corresponding to four channels. How can I keep the top 50% of channels and remove all the others?

Use `torch.multinomial` to sample the channel indices to keep.
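For example, a minimal sketch (the sampled indices are random; higher-probability channels are just more likely to be drawn):

```
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
# Sample 2 distinct channel indices, weighted by probs
keep_idx = torch.multinomial(probs, num_samples=2, replacement=False)
print(keep_idx)  # e.g. tensor([0, 1]) -- varies from run to run
```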

I think I didn’t explain the question correctly, so let me rephrase it. I have a set of inputs X with shape (B, H, W, C) and a probability vector p whose length equals the number of channels in X, so each channel is associated with one probability. I want to keep only the 50% of channels with the highest probabilities. The expected output should have shape (B, H, W, C/2). Is there a way to achieve this efficiently?
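For concreteness, here is a toy sketch of the behavior I’m after (written naively; `torch.topk` here is just one assumed way to pick the top channels):

```
import torch

B, H, W, C = 2, 5, 5, 4
x = torch.randn(B, H, W, C)
p = torch.tensor([0.5, 0.3, 0.1, 0.1])  # one probability per channel

# Keep the C // 2 channels with the highest probabilities
_, idx = torch.topk(p, C // 2)
out = x[..., idx]
print(out.shape)  # torch.Size([2, 5, 5, 2])
```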

Each channel therefore has its own probability, i.e. the probabilities do not have to sum to 1, is that correct?

A valid probability tensor would then also be: `torch.tensor([0.9, 0.9, 0.9, 0.9])`?

Well, you could use `torch.bernoulli` to sample the channels, but you could end up with anywhere from 0 to 4 kept channels.

How exactly would you like to keep the 50% of channels with the highest probabilities?
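To illustrate the variable count, a small sketch:

```
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
# Each channel is kept independently, so the number of kept
# channels varies from draw to draw (anywhere from 0 to 4)
counts = [int(torch.bernoulli(probs).sum()) for _ in range(10)]
print(counts)  # e.g. [1, 2, 0, 1, 3, 1, 2, 2, 0, 1]
```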

If I understand your use case correctly, you could try to sample many times until you get two channels:

```
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
channel_idx = torch.bernoulli(probs)
while channel_idx.sum() != 2:
    channel_idx = torch.bernoulli(probs)
```

Would that work or am I misunderstanding the issue?

Thanks @ptrblck. One more doubt: once I have my probabilities `probs = torch.tensor([0.5, 0.3, 0.1, 0.1])`, I want to keep channel 0 and channel 1, as these channels have the highest probabilities. How do I do the slicing?

Probably I misunderstood your use case. I thought you wanted to sample using the probabilities.

In case you just want to keep the two channels with the highest probabilities, you could use a mask or slice the input. This depends on your current use case.

```
import torch
import torch.nn as nn

batch_size = 1
c, h, w = 4, 5, 5
x = torch.randn(batch_size, c, h, w)
probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
_, idx = torch.topk(probs, 2)

# Slicing
modelB = nn.Conv2d(2, 1, 3, 1, 1, bias=False)
weight = modelB.weight.clone()
outputB = modelB(x[:, idx])

# Using a mask
modelA = nn.Conv2d(c, 1, 3, 1, 1, bias=False)
# Copy the weights into the "valid" channels to compare the outputs
with torch.no_grad():
    modelA.weight[:, idx] = weight
mask = torch.zeros(batch_size, c, 1, 1)
mask[:, idx] = 1.
outputA = modelA(x * mask)
torch.allclose(outputA, outputB)
> True
```

I got an answer related to this question. I don’t remember the source, but thanks to whoever gave it.

```
b, c, h, w = x.size()  # let the size be (2, 128, 32, 32)
slice_idx = [[i] for i in range(b)]
# idx is a tensor of size (2, 32) containing the indices to keep from the 2nd dimension (c)
# the output should have size (2, 32, 32, 32)
out = x[slice_idx, idx]
```

I want to ask whether what I’m doing is correct.

Hi, this method is exactly what the previous post was pointing to.
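For completeness, a small sketch showing that this advanced indexing is equivalent to a `torch.gather` along the channel dimension (assuming `idx` has shape `(b, 32)`, as in the snippet above):

```
import torch

b, c, h, w = 2, 128, 32, 32
x = torch.randn(b, c, h, w)
# For each sample, the 32 channel indices to keep, shape (2, 32)
idx = torch.stack([torch.randperm(c)[:32] for _ in range(b)])

# Advanced indexing, as in the snippet above
slice_idx = [[i] for i in range(b)]
out_indexing = x[slice_idx, idx]

# Equivalent gather along the channel dimension
out_gather = x.gather(1, idx[:, :, None, None].expand(-1, -1, h, w))
print(torch.equal(out_indexing, out_gather))  # True
```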