Right way to make a one-hot encoding for segmentation?

I have a target of size NxHxW, where N is the batch size. I want to create a one-hot encoding of size NxCxHxW for image segmentation.

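import torch

def one_hot(targets, C):
    # targets: LongTensor of class indices in [0, C-1], shape NxHxW
    targets_extend = targets.clone()
    targets_extend.unsqueeze_(1)  # convert to Nx1xHxW (only the copy changes)
    one_hot = torch.zeros(targets_extend.size(0), C, targets_extend.size(2), targets_extend.size(3),
                          device=targets.device)  # float32 zeros on the same device as targets
    one_hot.scatter_(1, targets_extend, 1)  # write a 1 at each pixel's class index along dim 1
    return one_hot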

In the above code, I used targets_extend=targets.clone() because targets is also used in the cross-entropy loss, which only accepts size NxHxW, not Nx1xHxW. To be clearer: the one-hot encoding is used for the BCE loss, while targets is used for the cross-entropy loss. Is my approach (targets_extend=targets.clone()) correct for making the one-hot encoding? Could I use targets_extend=targets.detach() instead? Thanks

I would stick with .clone, since .detach still uses the same underlying data.
Your in-place unsqueeze_ will therefore also unsqueeze targets:

import torch

x = torch.ones(10)
y = x.detach()
y.unsqueeze_(0)
print(x.shape)
> torch.Size([1, 10])
print(y.shape)
> torch.Size([1, 10])
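For comparison, here is a small sketch (not part of the original reply) of two ways to get the Nx1xHxW index without touching targets: unsqueeze_ on a clone, or the out-of-place unsqueeze(1), which returns a new view and leaves the original shape alone. The target tensor below is made-up example data.

```python
import torch

targets = torch.randint(0, 4, (2, 3, 3))  # hypothetical NxHxW class-index target

# Option 1: clone, then modify the copy in place
targets_extend = targets.clone()
targets_extend.unsqueeze_(1)  # Nx1xHxW; only the clone's shape changes

# Option 2: out-of-place unsqueeze returns a new view; targets keeps its shape
targets_view = targets.unsqueeze(1)

print(targets.shape)         # torch.Size([2, 3, 3])
print(targets_extend.shape)  # torch.Size([2, 1, 3, 3])
print(targets_view.shape)    # torch.Size([2, 1, 3, 3])
```

The view in option 2 shares memory with targets, which is fine here because scatter_ only reads its index argument; option 1 costs a copy but is safe even if the copy is modified later.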

May I ask why you need BCELoss?
It doesn't seem like you have a multi-label setup. Or are you manipulating the one_hot target somehow afterwards?


Thanks. I use it to compute the adversarial loss over multiple classes (e.g., the 21 classes in VOC), while the cross-entropy is used for the segmentation loss. As I mentioned, cross-entropy expects targets of size NxHxW, while BCELoss expects NxCxHxW. So your answer is to use .clone, am I right?

Yes, I would use clone and provide two different targets for the different criteria.
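A minimal sketch of "two different targets for the different criteria", assuming a segmentation network and a per-class discriminator that both output NxCxHxW logits (seg_logits and disc_logits are hypothetical random stand-ins). It uses F.binary_cross_entropy_with_logits, the logit-input counterpart of BCELoss, so no sigmoid is needed on the made-up outputs.

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 21, 4, 4  # 21 classes as in the VOC example
targets = torch.randint(0, C, (N, H, W))  # NxHxW class indices (int64)

seg_logits = torch.randn(N, C, H, W)   # hypothetical segmentation-network output
disc_logits = torch.randn(N, C, H, W)  # hypothetical per-class discriminator output

# Target 1: the original NxHxW tensor for cross-entropy
seg_loss = F.cross_entropy(seg_logits, targets)

# Target 2: a separate NxCxHxW one-hot float tensor for the BCE criterion
one_hot = torch.zeros(N, C, H, W).scatter_(1, targets.unsqueeze(1), 1.0)
adv_loss = F.binary_cross_entropy_with_logits(disc_logits, one_hot)
```

Since targets is never modified in place, both criteria receive a tensor of the shape they expect.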

This solution works when the values in the target segmentation are valid indices for the dimension we scatter into,
i.e., with C classes, the mask values must range from 0 to C-1.
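Both scatter_ and F.one_hot expect class indices in 0..C-1, so a mask stored as 1..C has to be shifted by one first. A short sketch with a made-up 2x2 mask (F.one_hot is used for brevity; the same index range applies to the scatter_ version):

```python
import torch
import torch.nn.functional as F

C = 3
mask = torch.tensor([[1, 2], [3, 1]])  # hypothetical mask labeled 1..C

try:
    F.one_hot(mask, num_classes=C)  # label C is out of range for C classes
    raised = False
except RuntimeError:
    raised = True

one_hot = F.one_hot(mask - 1, num_classes=C)  # shift labels to 0..C-1, shape 2x2x3
```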

However, how do we one-hot encode a segmentation mask with 5 classes whose intensities lie in the range [0, 255], with only 5 unique values in the mask?

Please let me know if the question is not clear.
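One way to handle this is to map each intensity to a class index with a 256-entry lookup table and then one-hot encode the indices. This is a sketch that assumes the 5 gray values are known in advance; the palette below is hypothetical and should be replaced with the values actually present in your masks.

```python
import torch
import torch.nn.functional as F

# Hypothetical palette: the 5 gray values used in the masks, in class order
palette = torch.tensor([0, 64, 128, 192, 255])

# Lookup table mapping each 0..255 intensity to a class index (-1 = unexpected value)
lut = torch.full((256,), -1, dtype=torch.long)
lut[palette] = torch.arange(len(palette))

mask = torch.tensor([[0, 255], [128, 64]], dtype=torch.uint8)  # hypothetical HxW mask
class_idx = lut[mask.long()]  # HxW tensor of indices in 0..4
assert (class_idx >= 0).all(), "mask contains a value not in the palette"

one_hot = F.one_hot(class_idx, num_classes=len(palette))  # HxWx5
```

torch.unique(mask, return_inverse=True) also yields 0-based indices, but only relative to the values present in that particular mask: if one class is missing, the later classes shift down. A fixed palette keeps the mapping consistent across images.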

@ptrblck
true_1_hot = torch.eye(num_classes)[targs.squeeze(1)]

How does this work? With num_classes=4, the shape of targs after the squeeze is BS,H,W, yet the resulting shape is BS,H,W,4. I fail to understand this.

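torch.eye(num_classes) creates a tensor of shape [num_classes, num_classes] with 1s on its diagonal. Indexing it with [targs.squeeze(1)] then works “row-wise”: every value i in targs selects row i, which is exactly the one-hot vector for class i. The result therefore has the shape of the index tensor plus a trailing dimension of size num_classes, i.e. BS,H,W,4.
You would get the same values using F.one_hot(targs.squeeze(1), num_classes=num_classes), although F.one_hot returns an int64 tensor rather than a float one.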
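A quick check of this equivalence, with made-up targs of shape BSx1xHxW as in the question. The last line moves the class dimension to dim 1, which gives the NxCxHxW layout asked for at the start of the thread.

```python
import torch
import torch.nn.functional as F

num_classes = 4
targs = torch.randint(0, num_classes, (2, 1, 3, 3))  # hypothetical BSx1xHxW target

# Row i of the identity is the one-hot vector for class i -> shape BSxHxWx4
true_1_hot = torch.eye(num_classes)[targs.squeeze(1)]

# F.one_hot gives the same values, but with dtype int64
via_f = F.one_hot(targs.squeeze(1), num_classes=num_classes)

# Move the class dimension to dim 1 for an NxCxHxW layout (e.g. for BCE)
nchw = true_1_hot.permute(0, 3, 1, 2)
```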

To dig into the indexing using multiple dimensions, check this example code snippet:

import torch

a = torch.arange(10)
b = torch.randint(0, 10, (16,))

# index with "flat" dimension
out_b = a[b]

# index with multiple dimensions
c = b.view(2, 2, 2, 2)
out_c = a[c]

# compare
print((out_b == out_c.view(-1)).all())
> tensor(True)
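The same rule carries over to the torch.eye case: indexing a 2D tensor with an integer index tensor selects whole rows, so the output shape is the index shape followed by the row length. A small sketch extending the snippet above:

```python
import torch

table = torch.eye(4)                     # 4x4, each row is a one-hot vector
idx = torch.randint(0, 4, (16,))

out_flat = table[idx]                    # 16x4: one row per index
out_multi = table[idx.view(2, 2, 2, 2)]  # 2x2x2x2x4: index shape + row length

print(out_flat.shape)   # torch.Size([16, 4])
print(out_multi.shape)  # torch.Size([2, 2, 2, 2, 4])
print(torch.equal(out_flat, out_multi.view(-1, 4)))  # True
```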