Squash a multi-channel tensor into one image channel using per-pixel maximums

Hello there,

I have programmed a multiclass segmentation model, and everything works fine. For the test images my output is a tensor of shape 668x388x388, i.e. it contains 668 different classes. My goal is to turn this into a tensor (or NumPy array) of shape 1x388x388 that I can pass to matplotlib.imshow. I applied a softmax, so every value is between 0 and 1. Now, for each pixel of the resulting array, I need to pick the highest of the 668 channels…

Is there a method for this?

To create predictions containing, for each pixel, the index of the class with the highest probability (or logit), you could use:

```python
preds = torch.argmax(output, dim=0)  # assuming output has the mentioned shape [nb_classes, height, width]
```
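Here is a small self-contained sketch of the whole step. It uses a random tensor with made-up small sizes (5 classes, 4x6 pixels) in place of the real 668x388x388 output, since only the shapes matter. It also checks that applying softmax first does not change the argmax, and shows that the result is already 2D:

```python
import torch

# Hypothetical stand-in for the model output: [nb_classes, height, width].
# Small sizes instead of 668x388x388 keep the example fast.
nb_classes, height, width = 5, 4, 6
torch.manual_seed(0)
logits = torch.randn(nb_classes, height, width)
probs = torch.softmax(logits, dim=0)  # per-pixel probabilities over the class dim

# Index of the most likely class for every pixel -> shape [height, width]
preds = torch.argmax(probs, dim=0)

# softmax is monotonic, so the argmax over the raw logits gives the same map
same_as_logits = torch.equal(preds, torch.argmax(logits, dim=0))

# Move to CPU and convert to NumPy; the result is a 2D (H, W) array
pred_map = preds.cpu().numpy()
print(pred_map.shape)  # (4, 6)
```

With the real output, `plt.imshow(preds.cpu().numpy())` should then display the class map. matplotlib's imshow accepts (M, N) or (M, N, 3/4) arrays, so a literal 1x388x388 array would be rejected. If that leading dimension is needed elsewhere, `preds.unsqueeze(0)` adds it.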

Thanks! A 3D max pool (nn.MaxPool3d) should also be possible :D
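A sketch of the max-pooling idea, again with small made-up sizes. If the class dimension is treated as the depth of `nn.MaxPool3d` with `kernel_size=(nb_classes, 1, 1)`, the pool yields the highest *probability* per pixel, not the class label. The class index has to be recovered from the flattened indices returned with `return_indices=True`. `torch.max(..., dim=0)` gives both values and indices in one call, so for a class map to show with imshow the argmax above is the more direct route:

```python
import torch
import torch.nn as nn

# Hypothetical small stand-in for the softmax output: [nb_classes, height, width]
nb_classes, height, width = 5, 4, 6
torch.manual_seed(0)
probs = torch.softmax(torch.randn(nb_classes, height, width), dim=0)

# nn.MaxPool3d expects [N, C, D, H, W]; treat the class dim as depth D and
# pool over all of it at once (stride defaults to kernel_size)
pool = nn.MaxPool3d(kernel_size=(nb_classes, 1, 1), return_indices=True)
pooled, flat_idx = pool(probs[None, None])  # both shaped [1, 1, 1, H, W]

max_probs = pooled[0, 0, 0]  # [H, W]: highest probability per pixel
# flat_idx indexes the flattened D*H*W input volume, so integer-divide by
# H*W to get the depth position, i.e. the class index
class_from_pool = flat_idx[0, 0, 0] // (height * width)

# torch.max over dim 0 returns the same values plus the class indices directly
values, indices = probs.max(dim=0)
```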