How do I get single-channel output when doing semantic segmentation?

Dear all,
The problem is this: when I implement semantic segmentation, the output of the network has the shape [batch_size, num_classes, width, height] and looks like a one-hot encoding. I want to get the output directly with the shape [batch_size, width, height], where each pixel has a value 0, 1, 2, … indicating the predicted class index. What can I do to get an output of shape [batch_size, width, height] without post-processing such as argmax()?
Please help me. Thanks, all.


In a semantic segmentation use case the output of your model would represent logits and would thus not be one-hot encoded. To get the predicted class label for each pixel, you would apply argmax over the class dimension (dim=1). Since this operation is not differentiable, it's not included in the forward pass of the model.
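A minimal sketch of this step, assuming a hypothetical stand-in model (a single 1x1 convolution producing 4 class channels) in place of a real segmentation network. Note that PyTorch tensors are usually laid out as [N, C, H, W]; argmax over dim=1 collapses the class channel into one integer label per pixel:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a segmentation network: a single 1x1 conv
# mapping 3 input channels to num_classes logit channels.
num_classes = 4
model = nn.Conv2d(3, num_classes, kernel_size=1)

x = torch.randn(2, 3, 8, 8)          # [batch_size, channels, height, width]
logits = model(x)                     # [batch_size, num_classes, height, width]
preds = torch.argmax(logits, dim=1)   # [batch_size, height, width], values 0..num_classes-1

print(logits.shape)  # torch.Size([2, 4, 8, 8])
print(preds.shape)   # torch.Size([2, 8, 8])
print(preds.dtype)   # torch.int64
```

The logits still carry gradients, while the integer predictions do not. This is why the loss is computed on the logits and argmax is left out of the training forward pass.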


Hello ptrblck:
Thanks for your reply.

For reasons such as deploying the network on devices, the .pth file will be converted to another format like .dlc (made by Qualcomm), and what I get is just the result inferred by the network. So I have two options. One is to let the network output the shape [batch_size, num_classes, width, height] and then implement the argmax() function myself in a language like C++ to get the segmentation result. The other is to make the network output [batch_size, width, height] directly, so I don't need to implement argmax() myself.
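For the first option, the post-processing is just a loop over the class channel for every pixel. The sketch below shows the indexing in plain Python so it can be ported line by line to C++. It assumes the runtime returns a flat, contiguous buffer in NCHW order like a PyTorch tensor; that is an assumption, since a converted .dlc model may use a different layout (for example channels-last), so check the layout your runtime reports. The function name `argmax_nchw` is hypothetical:

```python
import torch

def argmax_nchw(buf, n, c, h, w):
    """Channel-wise argmax over a flat NCHW score buffer.

    Returns a flat list of length n*h*w with the class index of each pixel.
    Ties resolve to the lowest class index (strict '>' comparison).
    """
    out = [0] * (n * h * w)
    for b in range(n):
        for y in range(h):
            for x in range(w):
                best_cls = 0
                best_val = buf[((b * c + 0) * h + y) * w + x]
                for k in range(1, c):
                    val = buf[((b * c + k) * h + y) * w + x]
                    if val > best_val:
                        best_val = val
                        best_cls = k
                out[(b * h + y) * w + x] = best_cls
    return out

# Check the manual loop against torch.argmax on random logits.
n, c, h, w = 2, 3, 4, 5
logits = torch.randn(n, c, h, w)
flat = logits.flatten().tolist()
manual = argmax_nchw(flat, n, c, h, w)
expected = torch.argmax(logits, dim=1).flatten().tolist()
print(manual == expected)
```

If the buffer turns out to be channels-last (NHWC), only the index expressions change: the score of class k at pixel (y, x) sits at ((b * h + y) * w + x) * c + k.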

In your reply, do you mean that when I use a multi-class semantic segmentation network, its output shape must be [batch_size, num_classes, width, height]? And that I also need to use the argmax() function to get the final single-channel result?

Thank you.

Yes, that's the standard workflow. I'm not familiar with the .dlc format, but if you can call a custom method of your deployed model, you could also create e.g. a predict method, which calls the forward method and applies torch.argmax internally. Once deployed, you could then use preds = model.predict(input).
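A minimal sketch of such a predict method, assuming a small hypothetical `SegNet` model (the layers are placeholders for your real architecture). forward keeps returning logits for training, and predict adds the argmax on top:

```python
import torch
import torch.nn as nn

class SegNet(nn.Module):
    """Hypothetical minimal segmentation model used only for illustration."""

    def __init__(self, in_channels=3, num_classes=4):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, num_classes, kernel_size=1),
        )

    def forward(self, x):
        # Training path: raw logits [N, num_classes, H, W] for the loss.
        return self.body(x)

    @torch.no_grad()
    def predict(self, x):
        # Inference path: reuse forward, then collapse the class dimension.
        return torch.argmax(self.forward(x), dim=1)  # [N, H, W]

model = SegNet().eval()
preds = model.predict(torch.randn(1, 3, 32, 32))
print(preds.shape)  # torch.Size([1, 32, 32])
```

Whether a converter can reach a method other than forward depends on the export path; many export tools only capture forward, in which case the argmax has to live in forward itself.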

Hello ptrblck:
Thank you for your kind reply to the above question.

I think I understand what you mean: you suggest adding argmax() in a predict function so that when I run the predict function I get the shape I want.

Besides the method you proposed, can I try this one? I train the model without argmax() in the forward function (because it is not differentiable), and when I use the weight file for inference, I modify the model (e.g. add argmax() to the forward function) so that I get the result with shape [batch_size, width, height]. I am not sure whether this works.
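One way to add argmax at inference time without touching the training class is to wrap the trained model in a small inference-only module. This is a sketch: `ArgmaxWrapper` is a hypothetical name, and the 1x1 convolution stands in for your trained network:

```python
import torch
import torch.nn as nn

class ArgmaxWrapper(nn.Module):
    """Hypothetical inference-only wrapper whose forward returns class indices."""

    def __init__(self, seg_model: nn.Module):
        super().__init__()
        self.seg_model = seg_model

    def forward(self, x):
        logits = self.seg_model(x)          # [N, num_classes, H, W]
        return torch.argmax(logits, dim=1)  # [N, H, W]

# Stand-in for your trained network. In practice you would construct your own
# model class and load the trained weights into it first, e.g.
#   seg_model.load_state_dict(torch.load("weights.pth"))  # hypothetical path
seg_model = nn.Conv2d(3, 5, kernel_size=1)

deploy_model = ArgmaxWrapper(seg_model).eval()
with torch.no_grad():
    out = deploy_model(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 16, 16])
```

One detail with a wrapper: its state_dict keys get the prefix `seg_model.`, so load the trained weights into the inner model before wrapping (as in the comment), not into the wrapper itself.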

I also have another question: if I modify the model (adding argmax() to it) when I implement inference, I think there may be some problems when I load the weight file (errors like incompatible keys, etc.).

Yes, this would be my proposed approach if you can't add the argmax operation in your deployment setup. Manipulating the forward would work, but the cleaner approach would probably be to write a custom predict method which adds the argmax operation, as described before. You could also return two outputs in the forward, the logits as well as the class predictions, and use the one corresponding to training or inference. The "best" approach depends on your workflow.
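A sketch of the two-output variant, assuming a hypothetical single-layer model: training uses the logits for the loss, inference uses the class indices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SegNet(nn.Module):
    """Hypothetical model whose forward returns (logits, class predictions)."""

    def __init__(self, in_channels=3, num_classes=4):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)                # [N, num_classes, H, W]
        preds = torch.argmax(logits, dim=1)  # [N, H, W], not differentiable
        return logits, preds

model = SegNet()
x = torch.randn(2, 3, 8, 8)
target = torch.randint(0, 4, (2, 8, 8))  # per-pixel class indices

# Training: compute the loss on the logits and ignore the predictions.
logits, _ = model(x)
loss = F.cross_entropy(logits, target)
loss.backward()

# Inference: use the class indices directly.
model.eval()
with torch.no_grad():
    _, preds = model(x)
print(preds.shape)  # torch.Size([2, 8, 8])
```

The trade-off is that the argmax also runs during training (cheap, but unused), and an exported graph would have two outputs, so the deployment side has to pick the second one.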

No, you should not see any state_dict mismatches, as torch.argmax is a purely functional call, won’t have any parameters, and will thus not be stored in the state_dict.
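This can be checked directly. In the sketch below (hypothetical models `LogitsNet` and `PredsNet`), the inference class only adds a functional torch.argmax call in forward, so both classes register exactly the same parameters and buffers, and strict loading reports no missing or unexpected keys:

```python
import torch
import torch.nn as nn

class LogitsNet(nn.Module):
    """Hypothetical training version: forward returns logits."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=1)
        self.bn = nn.BatchNorm2d(4)

    def forward(self, x):
        return self.bn(self.conv(x))

class PredsNet(LogitsNet):
    """Same submodules; forward only adds a functional torch.argmax call."""

    def forward(self, x):
        return torch.argmax(super().forward(x), dim=1)

train_sd = LogitsNet().state_dict()   # stands in for the saved .pth contents
deploy = PredsNet()
result = deploy.load_state_dict(train_sd)  # strict=True by default
print(result.missing_keys, result.unexpected_keys)  # [] []
```

Because PredsNet subclasses LogitsNet, the key names (e.g. `conv.weight`, `bn.running_mean`) are unchanged, which is what lets the training checkpoint load without any renaming.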

Hello ptrblck:

I got what you mean. I will try these approaches later.

Thank you very much. Have a nice day 😀