Remove a few channels from input

I want to remove a few channels depending on some criterion. For example, I may remove them based on probabilities. Suppose my probability vector is [0.5, 0.3, 0.1, 0.1], one entry for each of four channels. How can I keep the top 50% of the channels and remove all the others?

Use torch.multinomial to sample the indices of the channels to keep.
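A minimal sketch of that suggestion. It assumes the goal is to draw 2 of the 4 channel indices at random, weighted by `probs` and without replacement, and then keep only those channels. The input `x` and its (B, C, H, W) layout are illustrative assumptions.

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])  # one weight per channel
x = torch.randn(8, 4, 16, 16)               # illustrative (B, C, H, W) input

# Draw 2 distinct channel indices; higher-probability channels are more likely to be picked
keep = torch.multinomial(probs, num_samples=2, replacement=False)
out = x[:, keep]                            # shape (8, 2, 16, 16)
print(keep, out.shape)
```

This is random, so it will not always keep channels 0 and 1. For a deterministic "top 50%" selection, `torch.topk` is the usual tool.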

I think I didn’t explain the question clearly, so let me rephrase it. I have a set of inputs X with shape (B, H, W, C) and a probability vector p whose length equals the number of channels in X, so each channel is associated with one probability. I want to keep only the 50% of the channels with the highest probabilities, so the expected output should have shape (B, H, W, C/2). Is there a way to achieve this efficiently?
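A minimal sketch for the deterministic reading of this question, using the channels-last (B, H, W, C) layout described above: `torch.topk` finds the channels with the largest probabilities, and indexing the last dimension keeps only those. The sizes are placeholders.

```python
import torch

B, H, W, C = 2, 5, 5, 4
X = torch.randn(B, H, W, C)              # channels-last input, as in the question
p = torch.tensor([0.5, 0.3, 0.1, 0.1])   # one probability per channel

k = C // 2                               # keep 50% of the channels
_, idx = torch.topk(p, k)                # indices of the k largest probabilities
out = X[..., idx]                        # select along the channel (last) dimension
print(idx, out.shape)                    # idx -> tensor([0, 1]), out.shape -> (2, 5, 5, 2)
```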

So each channel has its own probability, i.e. the probabilities for the channels do not have to sum to 1, is that correct?
A valid probability tensor would also be: torch.tensor([0.9, 0.9, 0.9, 0.9])?

Well, you could use torch.bernoulli to sample each channel independently, but you could end up with anywhere from 0 to 4 selected channels.
How would you like to keep the 50% of the channels with the highest probabilities?

If I understand your use case correctly, you could try to sample many times until you get two channels:

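```python
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
# Sample a 0/1 mask (1 = keep the channel) and redraw until exactly two channels are kept
channel_idx = torch.bernoulli(probs)
while channel_idx.sum() != 2:
    channel_idx = torch.bernoulli(probs)
```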

Would that work or am I misunderstanding the issue?
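The rejection-sampling loop above produces a 0/1 mask rather than channel indices. A hedged sketch of how that mask could then be applied to a (B, C, H, W) input, either with boolean indexing or by converting it to indices via `nonzero`. The input `x` is illustrative, and `torch.manual_seed` is only there to make runs repeatable.

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
x = torch.randn(1, 4, 5, 5)  # illustrative (B, C, H, W) input

# Rejection sampling: redraw until exactly two channels are selected
channel_mask = torch.bernoulli(probs)
while channel_mask.sum() != 2:
    channel_mask = torch.bernoulli(probs)

keep = channel_mask.nonzero().squeeze(1)   # 1-D tensor holding the two selected channel indices
out_idx = x[:, keep]                       # shape (1, 2, 5, 5)
out_bool = x[:, channel_mask.bool()]       # same selection via a boolean mask on dim 1
print(keep, torch.equal(out_idx, out_bool))
```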

Thanks @ptrblck. One more question: given probs = torch.tensor([0.5, 0.3, 0.1, 0.1]), I want to keep channels 0 and 1 because they have the highest probabilities. How do I do the slicing?

Probably I misunderstood your use case. I thought you wanted to sample using the probabilities.
In case you just want to keep the two channels with the highest probabilities, you could use a mask or slice the input. This depends on your current use case.

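```python
import torch
import torch.nn as nn

batch_size = 1
c, h, w = 4, 5, 5
x = torch.randn(batch_size, c, h, w)  # PyTorch's default (N, C, H, W) layout

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
_, idx = torch.topk(probs, 2)  # indices of the 2 highest probabilities

# Slicing: feed only the selected channels to a conv with 2 input channels
modelB = nn.Conv2d(2, 1, 3, 1, 1, bias=False)
weight = modelB.weight.clone()
outputB = modelB(x[:, idx])

# Using a mask: keep all c input channels but zero out the unselected ones
modelA = nn.Conv2d(c, 1, 3, 1, 1, bias=False)

# Copy modelB's weights into the "valid" channels so the outputs can be compared
with torch.no_grad():
    modelA.weight[:, idx] = weight

mask = torch.zeros(batch_size, c, 1, 1)
mask[:, idx] = 1.
outputA = modelA(x * mask)

print(torch.allclose(outputA, outputB))
# True
```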
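To make the difference between the two options concrete: slicing removes the unused channels, so downstream layers must expect only `k` input channels, while masking keeps all `c` channels and just zeroes the unselected ones. A small sketch checking the shapes, plus `torch.index_select`, which selects the same channels as `x[:, idx]`:

```python
import torch

x = torch.randn(1, 4, 5, 5)
probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
_, idx = torch.topk(probs, 2)

sliced = x[:, idx]                          # (1, 2, 5, 5): channels removed
selected = torch.index_select(x, 1, idx)    # equivalent to x[:, idx]

mask = torch.zeros(1, 4, 1, 1)
mask[:, idx] = 1.
masked = x * mask                           # (1, 4, 5, 5): channels zeroed, not removed

print(sliced.shape, masked.shape, torch.equal(sliced, selected))
```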

I found an answer related to this question elsewhere. I don’t remember the source, but thanks to whoever gave it.

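```python
import torch

x = torch.randn(2, 128, 32, 32)
b, c, h, w = x.size()  # let the size be (2, 128, 32, 32)
# idx is a tensor of size (2, 32) containing the indices to keep from the 2nd dimension (c);
# here it is filled with random per-sample indices purely for illustration
idx = torch.topk(torch.rand(b, c), 32, dim=1).indices
slice_idx = [[i] for i in range(b)]  # batch indices of shape (b, 1), broadcast against idx
out = x[slice_idx, idx]  # the output should have size (2, 32, 32, 32)
```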

Is what I’m doing correct?


Hi, this method is exactly what the previous post is pointing to.
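A hedged sketch that checks the per-sample indexing from the question: it compares advanced indexing with a (b, 1) batch index against `torch.gather`, and shows that when every sample uses the same indices it reduces to the `x[:, idx]` slicing from the earlier reply. The random per-sample `scores` used to build `idx` are a hypothetical placeholder.

```python
import torch

b, c, h, w = 2, 128, 32, 32
x = torch.randn(b, c, h, w)

# Hypothetical per-sample scores; keep the 32 highest-scoring channels of each sample
scores = torch.rand(b, c)
idx = torch.topk(scores, 32, dim=1).indices            # (b, 32)

# Advanced indexing: batch index (b, 1) broadcasts against idx (b, 32)
out = x[torch.arange(b).unsqueeze(1), idx]             # (b, 32, h, w)

# Same selection with gather: expand idx over the spatial dims
gather_idx = idx[:, :, None, None].expand(-1, -1, h, w)
out_gather = torch.gather(x, 1, gather_idx)
print(out.shape, torch.equal(out, out_gather))

# With identical indices for every sample, this reduces to plain slicing
shared = idx[0]
same = x[torch.arange(b).unsqueeze(1), shared.expand(b, -1)]
print(torch.equal(same, x[:, shared]))
```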