Softmax for multiclass segmentation

Hello,

I have a tensor, the output of my network, that represents multi-class semantic segmentation. It has the shape [B, C, H, W].

Here B is the batch size, C is the number of classes, H is the image height and W is the image width. I want to turn these values into a probability distribution over the classes at each pixel (for each image in the batch). How would I accomplish this with torch.nn.Softmax?

Thanks!

You could use:

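import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W = 2, 3, 4, 4
x = torch.randn(B, C, H, W)  # raw network outputs (logits)

# Module form: normalize over the class dimension (dim=1)
act = nn.Softmax(dim=1)
out = act(x)
print(out.sum(1))  # shape [B, H, W], every entry is ~1.0

# Functional form, same result
out = F.softmax(x, dim=1)
print(out.sum(1))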
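To see why dim=1 is the right choice, here is a minimal sketch that writes the per-pixel softmax out by hand and compares it with F.softmax. The helper name manual_softmax_over_classes and the pixel indices are illustrative only; the point is that the exponentials are normalized across the C axis, so each pixel's C values form a probability distribution.

```python
import torch
import torch.nn.functional as F


def manual_softmax_over_classes(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax over the class dim of a [B, C, H, W] tensor."""
    # Subtract the per-pixel max for numerical stability; the result is unchanged.
    shifted = x - x.amax(dim=1, keepdim=True)
    exp = shifted.exp()
    # Divide by the sum over classes (dim=1), keeping dims so it broadcasts.
    return exp / exp.sum(dim=1, keepdim=True)


B, C, H, W = 2, 3, 4, 4
torch.manual_seed(0)
x = torch.randn(B, C, H, W)

probs = F.softmax(x, dim=1)
manual = manual_softmax_over_classes(x)

# The built-in and hand-written versions agree elementwise
print(torch.allclose(probs, manual, atol=1e-6))

# The C class probabilities at one pixel (image 0, row 1, column 2)
pixel = probs[0, :, 1, 2]
print(pixel.shape, pixel.sum())
```

Normalizing over dim=0 or dim=2 instead would run without error but would mix values across images or rows, which is why the class axis has to be named explicitly.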

That makes sense, thank you!