One-hot encoding of multi-class mask

Hello,
I am creating one-hot encoded labels from a 3D mask (with H x W x D shape) with the following code:

import nibabel as nib
import numpy as np

seg_3d = nib.load(seg_file)
seg = seg_3d.get_data()  # fixme: get_data() is deprecated in newer nibabel, get_fdata() is the replacement
labels = np.unique(seg)  # [ 0  1  2  3  4  8 10 11 56]
num_labels = len(labels)  # 9
segD = np.zeros((num_labels, seg.shape[0], seg.shape[1], seg.shape[2]))
for i in range(1, num_labels):  # start from label 1 to ignore background 0
    segD[i, :, :, :] = seg == labels[i]

Are there better and more efficient ways to do that?
I tried the torch.scatter_ method in PyTorch but couldn't work it out. Moreover, I need the one-hot encoding as a NumPy array, while .scatter_ from PyTorch works on tensors.
Any help?

These days there is torch.nn.functional.one_hot that does what you need, except that it puts the label dimension last; one_hot + permute will get you there.

You can (almost) always use torch.from_numpy(array) and tensor.numpy() to get from and to numpy efficiently.
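A minimal sketch of that route (assuming seg is an integer NumPy array whose values are already consecutive class indices 0 .. num_labels-1):

import numpy as np
import torch
import torch.nn.functional as F

seg_t = torch.from_numpy(seg).long()                 # numpy -> tensor; one_hot requires int64
seg_hot = F.one_hot(seg_t, num_classes=num_labels)   # shape: H x W x D x num_labels
segD = seg_hot.permute(3, 0, 1, 2).numpy()           # label dimension first, back to numpy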

Best regards

Thomas


Hello @tom

import numpy as np
import torch
from torch.nn.functional import one_hot

seg = seg_3d.get_data()  # 3D numpy array with 9 different label values
num_labels = len(np.unique(seg)) # 9
seg = torch.tensor(seg).to(torch.int64)
seg_hot = one_hot(seg, num_labels)
print(seg_hot.shape, torch.unique(seg_hot))

With this code I get the following error:

    seg_hot = one_hot(seg, num_labels)
RuntimeError: Class values must be smaller than num_classes.

It looks like this issue was discussed here: torch.nn.functional.one_hot should gracefully skip negative and out-of-range indices · Issue #45352 · pytorch/pytorch · GitHub
but I found no alternatives or solutions there.
I think some of the class values, e.g. 56, 11 or 10, are higher than the number of classes (9), which results in the error.

I think the problem you are seeing comes from the relabeling (or rather the lack of it): one_hot expects class indices in the range 0 .. num_classes-1, but your mask still contains the raw label values such as 56. You probably want to use np.unique with return_inverse=True and pass the output of that to one_hot.
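
A minimal sketch of that combination (assuming seg is the raw label volume from above):

labels, seg_idx = np.unique(seg, return_inverse=True)       # seg_idx holds consecutive indices 0 .. len(labels)-1
seg_idx = seg_idx.reshape(seg.shape)                        # restore H x W x D (np.unique may flatten)
seg_t = torch.from_numpy(seg_idx).long()
seg_hot = one_hot(seg_t, len(labels)).permute(3, 0, 1, 2)   # shape: num_labels x H x W x D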