Efficient way to one hot encode whole image for semantic segmentation

Hi, I am working on semantic segmentation. Currently, my training labels have shape [batch_size, channels, h, w] with channels = 1, and each pixel is either 0 (the pixel belongs to the background) or 1 (the pixel belongs to target_class). I need to convert this [batch_size, 1, h, w] label into a [batch_size, 2, h, w] label in which every 0 pixel becomes [1, 0] and every 1 pixel becomes [0, 1].
Basically, I want to one-hot encode each pixel.

I don’t want to iterate over every pixel with a for loop, since that would take far too long.
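For reference, the per-pixel loop being avoided looks roughly like the sketch below. It assumes integer labels of shape [batch_size, 1, h, w]; the function name `one_hot_loop` is hypothetical. It is mainly useful as a slow ground truth for checking a vectorized version.

```python
import torch

def one_hot_loop(label: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. label: [batch_size, 1, h, w], values 0 (background) or 1 (target_class)
    b, _, h, w = label.shape
    out = torch.zeros(b, 2, h, w, dtype=label.dtype)
    for n in range(b):
        for i in range(h):
            for j in range(w):
                # The pixel's class value selects which channel gets the 1:
                # 0 -> [1, 0] (background), 1 -> [0, 1] (target_class)
                out[n, int(label[n, 0, i, j]), i, j] = 1
    return out

label = torch.tensor([[[[0, 1], [1, 0]]]])  # shape [1, 1, 2, 2]
print(one_hot_loop(label).shape)  # torch.Size([1, 2, 2, 2])
```

With h × w Python-level iterations per image, this scales poorly for real image sizes, which is why a tensor-level operation is preferable.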

Can someone help me out? I’m sure there is a faster way to do this.


one_hot_label = torch.cat((1 - current_label, current_label), dim=1)


  1. (1 - current_label, current_label) is (1, 0) wherever the current label is 0, and (0, 1) wherever it is 1.
  2. dim=1 concatenates along the second dimension (index 1), which is the channel dimension, so the two [batch_size, 1, h, w] tensors become one [batch_size, 2, h, w] tensor.
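A minimal runnable check of the one-liner above, using a random binary label as a stand-in for real data (the shape 2×1×3×4 is an arbitrary assumption). It also compares the result against `torch.nn.functional.one_hot`, which puts the class dimension last, so the result has to be permuted back to channels-first.

```python
import torch
import torch.nn.functional as F

# Hypothetical binary label: batch_size=2, channels=1, h=3, w=4 (int64 by default)
current_label = torch.randint(0, 2, (2, 1, 3, 4))

# Background channel (1 - label) followed by target channel (label), joined along dim=1
one_hot_label = torch.cat((1 - current_label, current_label), dim=1)
print(one_hot_label.shape)  # torch.Size([2, 2, 3, 4])

# Cross-check: F.one_hot expects an int64 tensor and appends the class dimension last,
# so drop the singleton channel first, then move the class dimension to position 1.
reference = F.one_hot(current_label.squeeze(1), num_classes=2).permute(0, 3, 1, 2)
assert torch.equal(one_hot_label, reference)
```

The `torch.cat` trick only covers the two-class case. With more classes, the `F.one_hot(...).permute(0, 3, 1, 2)` form generalizes by changing `num_classes`, provided the labels are cast to int64 first (for example with `.long()` on a float label).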