I want to remove a few channels depending on some criterion. For example, I may remove them based on probabilities. If my probability vector is [0.5, 0.3, 0.1, 0.1], corresponding to four channels, how can I keep the top 50% of channels and remove all the others?
Use torch.multinomial to sample the channel indices to keep.
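For reference, a minimal sketch of that suggestion (the input shape and the number of sampled channels here are assumptions, not from the original post):

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
# sample 2 distinct channel indices, weighted by probs
idx = torch.multinomial(probs, num_samples=2, replacement=False)

x = torch.randn(1, 4, 5, 5)  # (B, C, H, W)
out = x[:, idx]              # keep only the sampled channels
print(out.shape)             # torch.Size([1, 2, 5, 5])
```

With `replacement=False` you always get exactly `num_samples` distinct channels, which avoids the variable channel count you would get from independent per-channel sampling.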
I think I didn’t explain the question correctly, so let me rephrase it. I have a set of inputs X with shape (B, H, W, C) and a probability vector p whose length equals the number of channels in X. Each channel is associated with one probability. I want to keep only the 50% of channels with the highest probabilities. The expected output should have shape (B, H, W, C/2). Is there a way to achieve this efficiently?
Each channel therefore has its own probability, i.e. the probabilities across channels do not have to sum to 1, is that correct?
A valid probability tensor would then also be:

```python
torch.tensor([0.9, 0.9, 0.9, 0.9])
```
Well, you could use torch.bernoulli to sample the channels, but you could end up with anywhere from 0 to 4 valid channels.
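To illustrate why the channel count varies, here is a small sketch (input shape assumed) where each channel is kept or dropped by an independent coin flip:

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
keep = torch.bernoulli(probs).bool()  # one independent Bernoulli draw per channel
x = torch.randn(1, 4, 5, 5)           # (B, C, H, W)
out = x[:, keep]                      # anywhere between 0 and 4 channels survive
print(out.shape[1])                   # varies from run to run
```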
How would you like to select the 50% of channels with the highest probabilities?
If I understand your use case correctly, you could try to sample many times until you get two channels:
```python
probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
channel_idx = torch.bernoulli(probs)
while channel_idx.sum() != 2:
    channel_idx = torch.bernoulli(probs)
```
Would that work or am I misunderstanding the issue?
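Applying the sampled mask to an input could then look like the following sketch (the input shape (B, C, H, W) is an assumption):

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
# resample until exactly two channels are selected
channel_idx = torch.bernoulli(probs)
while channel_idx.sum() != 2:
    channel_idx = torch.bernoulli(probs)

x = torch.randn(1, 4, 5, 5)
out = x[:, channel_idx.bool()]  # exactly two channels remain
print(out.shape)                # torch.Size([1, 2, 5, 5])
```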
Thanks @ptrblck. One more doubt: once I have my probabilities probs = torch.tensor([0.5, 0.3, 0.1, 0.1]), I want to keep channel 0 and channel 1, as these channels have the maximum probability. How do I do the slicing?
Probably I misunderstood your use case. I thought you wanted to sample using the probabilities.
In case you just want to keep the two channels with the highest probabilities, you could use a mask or slice the input. This depends on your current use case.
```python
import torch
import torch.nn as nn

batch_size = 1
c, h, w = 4, 5, 5
x = torch.randn(batch_size, c, h, w)
probs = torch.tensor([0.5, 0.3, 0.1, 0.1])
_, idx = torch.topk(probs, 2)

# Slicing
modelB = nn.Conv2d(2, 1, 3, 1, 1, bias=False)
weight = modelB.weight.clone()
outputB = modelB(x[:, idx])

# Using a mask
modelA = nn.Conv2d(c, 1, 3, 1, 1, bias=False)
# Set weights of the "valid" channels to compare the outputs
with torch.no_grad():
    modelA.weight[:, idx] = weight
mask = torch.zeros(batch_size, c, 1, 1)
mask[:, idx] = 1.
outputA = modelA(x * mask)

torch.allclose(outputA, outputB)
> True
```
I got an answer related to this question. I don’t remember the source, but thanks to whoever gave the answer.
```python
b, c, h, w = x.size()  # let the size be (2, 128, 32, 32)
slice_idx = [[i] for i in range(b)]
# idx is a tensor of size (2, 32) containing the indices to keep from the 2nd dimension (c)
# the output should have size (2, 32, 32, 32)
out = x[slice_idx, idx]
I want to ask whether what I’m doing Is correct?
Hi, yes, this method does exactly what the previous post is pointing to.
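For completeness, an equivalent formulation using torch.gather, under the shapes assumed in that post (the random `idx` here is just a stand-in for your real per-sample indices):

```python
import torch

b, c, h, w = 2, 128, 32, 32
x = torch.randn(b, c, h, w)
idx = torch.randint(0, c, (b, 32))  # per-sample channel indices, shape (2, 32)

# fancy-indexing version from the post
slice_idx = [[i] for i in range(b)]
out = x[slice_idx, idx]

# equivalent torch.gather version: expand idx over the spatial dims,
# then gather along the channel dimension (dim=1)
idx_exp = idx[:, :, None, None].expand(-1, -1, h, w)
out_gather = x.gather(1, idx_exp)

print(torch.equal(out, out_gather))  # True
```

Both produce an output of shape (2, 32, 32, 32); `torch.gather` avoids building the Python list of batch indices.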