# How to do channel selection

I have a tensor of shape [B,C,H,W].
I want a network that outputs, for every pixel, a probability for each channel, i.e. a tensor of shape [B,C,H,W].
I then want to use these probabilities to select the most likely channel for each pixel and output a tensor of shape [B,1,H,W].
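To make the desired operation concrete, here is a tiny hand-made example (the numbers are arbitrary and chosen only for illustration) of picking, per pixel, the channel with the highest probability. It uses `torch.gather` along `dim=1`, which is one way to express this selection, not necessarily the only one.

```python
import torch

# Toy input: B=1, C=3, H=1, W=2
x = torch.tensor([[[[10., 20.]],
                   [[30., 40.]],
                   [[50., 60.]]]])   # shape [1, 3, 1, 2]

# Per-channel probabilities with the same shape (hand-picked for illustration)
probs = torch.tensor([[[[0.2, 0.7]],
                       [[0.5, 0.1]],
                       [[0.3, 0.2]]]])

idx = probs.argmax(dim=1, keepdim=True)  # [1, 1, 1, 2]: winning channel per pixel
y = torch.gather(x, 1, idx)              # [1, 1, 1, 2]: value of that channel per pixel
print(idx.flatten().tolist())  # [1, 0]
print(y.flatten().tolist())    # [30.0, 20.0]
```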

So, this is what I have:

```python
import torch
import torch.nn as nn


class ChannelDecision(nn.Module):
    def __init__(self, inc):
        super(ChannelDecision, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(inc, inc, kernel_size=1, bias=True),
                                 nn.ReLU(inplace=True),
                                 nn.Conv2d(inc, inc, kernel_size=1, bias=True),
                                 nn.Softmax(dim=1))

    def forward(self, x):
        probs = self.net(x)                                   # [B, C, H, W]
        decisions = torch.argmax(probs, dim=1, keepdim=True)  # [B, 1, H, W]
        y = x[decisions]  # This doesn't do what I want and requires an ENORMOUS allocation
        return y


x = torch.randn(4, 32, 128, 128)
n = ChannelDecision(32)
y = n(x)  # make this work please
```

How can I do this?
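A sketch of one possible fix, assuming the goal is exactly the hard per-pixel selection described above. `x[decisions]` is advanced indexing on the first (batch) dimension, so each integer in `decisions` selects a whole `[C, H, W]` slice of `x`: the result would have shape `[4, 1, 128, 128, 32, 128, 128]` (about 3.4e10 elements), and any index >= 4 is out of range for the batch dimension anyway. `torch.gather(x, 1, decisions)` instead picks a single element along the channel dimension at each position, which gives `[B, 1, H, W]` directly.

```python
import torch
import torch.nn as nn


class ChannelDecision(nn.Module):
    """Pick, for every pixel, the input channel with the highest predicted probability."""

    def __init__(self, inc):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        probs = self.net(x)                                   # [B, C, H, W]
        decisions = torch.argmax(probs, dim=1, keepdim=True)  # [B, 1, H, W], int64
        # y[b, 0, h, w] = x[b, decisions[b, 0, h, w], h, w]
        return torch.gather(x, 1, decisions)                  # [B, 1, H, W]


x = torch.randn(4, 32, 128, 128)
n = ChannelDecision(32)
y = n(x)
print(y.shape)  # torch.Size([4, 1, 128, 128])
```

Note that `argmax` returns integer indices and has no gradient, so with this hard selection `self.net` receives no gradient from `y`. That is one reason a soft decision is easier to train.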

OK, my workaround is to use a soft decision instead of a hard one, i.e. a probability-weighted sum over the channels:

```python
import torch
import torch.nn as nn


class ChannelDecision(nn.Module):
    def __init__(self, inc):
        super(ChannelDecision, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(inc, inc, kernel_size=1, bias=True),
                                 nn.ReLU(inplace=True),
                                 nn.Conv2d(inc, inc, kernel_size=1, bias=True),
                                 nn.Softmax(dim=1))

    def forward(self, x):
        probs = self.net(x)         # per-pixel channel probabilities, [B, C, H, W]
        x = x * probs               # weight each channel by its probability
        x = x.sum(1, keepdim=True)  # sum over channels -> [B, 1, H, W]
        return x
```
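If a hard decision in the forward pass is still wanted, one commonly used option (a swapped-in technique, not something from the original post) is the straight-through estimator: build a one-hot mask from the argmax, but route the gradient through the softmax probabilities. A minimal sketch, assuming the same `net` as above; the class name `ChannelDecisionST` is hypothetical:

```python
import torch
import torch.nn as nn


class ChannelDecisionST(nn.Module):
    """Hypothetical variant: hard per-pixel channel selection in the forward pass,
    softmax-based gradients in the backward pass (straight-through estimator)."""

    def __init__(self, inc):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        probs = self.net(x)                                   # [B, C, H, W]
        idx = probs.argmax(dim=1, keepdim=True)               # [B, 1, H, W]
        hard = torch.zeros_like(probs).scatter_(1, idx, 1.0)  # one-hot mask over channels
        # Value is (numerically) `hard`; gradient w.r.t. probs is that of `probs`
        mask = hard - probs.detach() + probs
        return (x * mask).sum(dim=1, keepdim=True)            # [B, 1, H, W]


x = torch.randn(4, 32, 128, 128)
y = ChannelDecisionST(32)(x)
print(y.shape)  # torch.Size([4, 1, 128, 128])
```

In the forward pass this matches the `torch.gather` selection up to floating-point rounding; in the backward pass, gradients reach `self.net` as if the soft weighted sum had been used. The gradient is therefore a biased approximation of the hard selection, which may or may not matter for your task.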