Softmax to one hot

I have an output tensor from a semantic segmentation network of size (21,512,512), where each pixel has a softmax probability vector over the 21 classes. How can I convert it into an output of the same shape in which each pixel has a one-hot encoding instead?

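Are you sure you need to convert your output to one-hot? Most loss functions operate directly on the per-class scores (in PyTorch, nn.CrossEntropyLoss takes raw logits and nn.NLLLoss takes log-probabilities), not on a one-hot prediction.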
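To illustrate that point, here is a minimal sketch, assuming a PyTorch training setup where the labels are integer class indices per pixel. The tensor names and the small 8x8 spatial size are made up for the example. nn.CrossEntropyLoss accepts scores of shape (N, C, H, W) and a target of shape (N, H, W), so neither side needs a one-hot encoding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shapes: batch of 2, 21 classes, 8x8 pixels (kept small for illustration).
logits = torch.randn(2, 21, 8, 8, requires_grad=True)  # raw network scores, not softmaxed
target = torch.randint(0, 21, (2, 8, 8))               # integer class index per pixel

loss_fn = nn.CrossEntropyLoss()  # applies log-softmax over dim 1 internally
loss = loss_fn(logits, target)
loss.backward()                  # gradients flow back to the logits
```

If the network already ends in a softmax, one option is to drop it for training and feed the pre-softmax scores to the loss.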

If you do need to do this, however, you can take the argmax over the class dimension for each pixel, and then use scatter_ to write a 1 at that index.

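import torch

# Stand-in for the network output: softmax over the class dimension (dim 0)
probs = torch.softmax(torch.randn(21, 512, 512), dim=0)
# Index of the most likely class per pixel, shape (1, 512, 512)
max_idx = torch.argmax(probs, dim=0, keepdim=True)
# Start from zeros and write 1 at the argmax index along dim 0
one_hot = torch.zeros_like(probs)
one_hot.scatter_(0, max_idx, 1.0)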
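As an alternative to scatter_, the same result can be sketched with torch.nn.functional.one_hot. This assumes a PyTorch version that provides that function, and it uses a small 8x8 stand-in for the (21,512,512) tensor. Since one_hot appends the class dimension last, a permute is needed to get back to the (C, H, W) layout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Small stand-in for the (21, 512, 512) softmax output.
probs = torch.softmax(torch.randn(21, 8, 8), dim=0)

max_idx = probs.argmax(dim=0)                              # (H, W), class index per pixel
one_hot = F.one_hot(max_idx, num_classes=probs.shape[0])   # (H, W, C), integer dtype
one_hot = one_hot.permute(2, 0, 1).to(probs.dtype)         # back to (C, H, W), float
```

Each pixel of one_hot then contains a single 1 at the position of its most likely class.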

Note that argmax is not differentiable, so you will not be able to backpropagate through the indices (or the resulting one-hot tensor) to the network output. If you need gradients here, there are a few methods, e.g. REINFORCE or the Gumbel-Softmax.
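For the Gumbel-Softmax option, a rough sketch using torch.nn.functional.gumbel_softmax could look like the following. The shapes and the downstream weights tensor are invented for illustration. Note that this draws a random one-hot sample per pixel rather than the deterministic argmax, and that the function expects unnormalized log-probabilities (if you only have softmax probabilities, passing their log is one way to get them).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical unnormalized scores, small stand-in for (21, 512, 512).
logits = torch.randn(21, 8, 8, requires_grad=True)

# hard=True gives one-hot samples in the forward pass, while the backward
# pass uses the gradient of the soft relaxation (straight-through estimator).
sample = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=0)

# Arbitrary downstream computation, only here to show that gradients flow.
weights = torch.randn(21, 8, 8)
(sample * weights).sum().backward()
```

The temperature tau controls how close the relaxation is to a discrete sample; 1.0 here is just a placeholder value.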
