I have a tensor of shape [B,C,H,W].

I want a network that outputs, for each pixel, a probability for each channel, i.e. a tensor of shape [B,C,H,W] whose values sum to 1 over the channel dimension.

I then want to use these probabilities to select the best (most probable) channel at each pixel, and output a tensor of shape [B,1,H,W] holding the value of `x` in the selected channel.

So far, this is what I have:

```
import torch
import torch.nn as nn

class ChannelDecision(nn.Module):
    def __init__(self, inc):
        super(ChannelDecision, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(inc, inc, kernel_size=1, bias=True),
                                 nn.ReLU(inplace=True),
                                 nn.Conv2d(inc, inc, kernel_size=1, bias=True),
                                 nn.Softmax(dim=1))  # probabilities over channels, per pixel

    def forward(self, x):
        probs = self.net(x)                                   # [B,C,H,W]
        decisions = torch.argmax(probs, dim=1, keepdim=True)  # [B,1,H,W], best channel per pixel
        y = x[decisions]  # This doesn't do what I want and requires an ENORMOUS allocation
        return y

x = torch.randn(4, 32, 128, 128)
n = ChannelDecision(32)
y = n(x)  # make this work please
```
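The huge allocation comes from how plain indexing treats a single integer tensor: `x[decisions]` indexes dimension 0 (the batch), and each index element copies out a whole `[C, H, W]` slice, so the result has shape `decisions.shape + x.shape[1:]`. The snippet below is a small sketch with made-up toy shapes (B=4, C=3, H=W=2, and an all-zero `decisions` chosen only so every index is a valid batch index) that makes the shape visible.

```python
import torch

# Toy stand-in shapes (assumption for illustration): B=4, C=3, H=W=2
x = torch.randn(4, 3, 2, 2)
# Pretend every pixel picked channel 0; shape [B, 1, H, W] like the argmax output
decisions = torch.zeros(4, 1, 2, 2, dtype=torch.long)

# A lone integer tensor indexes dim 0 (batch); each element yields a full [C, H, W] slice
bad = x[decisions]
print(bad.shape)  # torch.Size([4, 1, 2, 2, 3, 2, 2])

# Element count at the question's sizes: [4, 1, 128, 128] + [32, 128, 128]
n_elems = 4 * 1 * 128 * 128 * 32 * 128 * 128
print(n_elems)  # 34359738368, i.e. 128 GiB as float32
```

At the question's sizes that is 2^35 float32 values, and the channel indices (up to 31) would not even be valid batch indices for B=4.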

How can I do this?
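One possible approach is a per-pixel lookup along the channel dimension with `torch.gather`, which computes `y[b, 0, h, w] = x[b, decisions[b, 0, h, w], h, w]` and so returns exactly [B,1,H,W]. This is a minimal sketch; `ChannelDecisionGather` is a hypothetical name for a copy of the module above with only the indexing line changed.

```python
import torch
import torch.nn as nn


class ChannelDecisionGather(nn.Module):
    """Hypothetical variant of ChannelDecision that picks one channel per pixel."""

    def __init__(self, inc):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(inc, inc, kernel_size=1, bias=True),
            nn.Softmax(dim=1),  # distribution over the C channels at every pixel
        )

    def forward(self, x):
        probs = self.net(x)                                   # [B, C, H, W]
        decisions = torch.argmax(probs, dim=1, keepdim=True)  # [B, 1, H, W], dtype long
        # Gather along dim=1: y[b, 0, h, w] = x[b, decisions[b, 0, h, w], h, w]
        y = torch.gather(x, 1, decisions)                     # [B, 1, H, W]
        return y


x = torch.randn(4, 32, 128, 128)
n = ChannelDecisionGather(32)
y = n(x)
print(y.shape)  # torch.Size([4, 1, 128, 128])
```

Note that `argmax` is not differentiable, so no gradient reaches `self.net` through `y`. If the selection network is meant to be trained through this output, a soft selection such as `(probs * x).sum(dim=1, keepdim=True)` is one alternative to consider.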