How to do per-pixel channel selection in PyTorch

I have a tensor of shape [B,C,H,W].
I want a network that outputs, for every pixel, a probability for each channel, i.e. a tensor of shape [B,C,H,W].
I then want to use the probabilities of each channel to select the best channel for each pixel, and output a tensor of shape [B,1,H,W].

So far, this is what I have:

class ChannelDecision(nn.Module):
    """Pick, for every pixel, the input channel the learned network scores highest.

    A two-layer 1x1-conv network followed by a softmax over dim 1 produces a
    probability per channel per pixel; ``forward`` returns the value of ``x``
    at the highest-probability channel.

    Args:
        inc: number of input channels ``C``.
    """

    def __init__(self, inc):
        super(ChannelDecision, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return the highest-probability channel of ``x`` at each pixel.

        Args:
            x: tensor of shape ``[B, C, H, W]`` with ``C == inc``.

        Returns:
            Tensor of shape ``[B, 1, H, W]``.
        """
        probs = self.net(x)
        # [B, 1, H, W] integer index of the winning channel at each pixel.
        decisions = torch.argmax(probs, dim=1, keepdim=True)
        # gather selects x[b, decisions[b, 0, h, w], h, w] for every (b, h, w).
        # Plain advanced indexing, x[decisions], indexes dim 0 instead and
        # materialises a tensor of shape decisions.shape + x.shape[1:].
        # NOTE: argmax is not differentiable, so no gradient from the output
        # reaches self.net; only x receives gradients through gather.
        return torch.gather(x, 1, decisions)

x = torch.randn(4, 32, 128, 128)
# Build the module before calling it; `n` was previously undefined (NameError).
n = ChannelDecision(x.shape[1])  # one probability per input channel
y = n(x)  # shape [4, 1, 128, 128]

How can I do this?

OK, my workaround is to use a soft decision instead of a hard one: weight each channel by its probability and sum over the channels. So I do this:

class ChannelDecision(nn.Module):
    """Soft per-pixel channel selection.

    Rather than picking a single channel, every channel of the input is
    weighted by a learned probability (softmax over dim 1) and the weighted
    channels are summed, collapsing the channel axis to size 1. No argmax is
    involved, so gradients flow back into ``self.net``.

    Args:
        inc: number of input channels ``C``.
    """

    def __init__(self, inc):
        super().__init__()
        layers = [
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.Softmax(dim=1),  # normalise over the channel axis
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the probability-weighted sum of ``x`` over its channels (keepdim)."""
        weights = self.net(x)
        return (x * weights).sum(1, keepdim=True)