I am creating one-hot encoded labels from a 3D mask (of shape HxWxD) with the following code:
```python
import nibabel as nib
import numpy as np

seg_3d = nib.load(seg_file)
seg = np.asanyarray(seg_3d.dataobj)  # get_data() is deprecated in nibabel; this keeps the integer label dtype
labels = np.unique(seg)  # [ 0 1 2 3 4 8 10 11 56]
num_labels = len(labels)  # 9
segD = np.zeros((num_labels, seg.shape[0], seg.shape[1], seg.shape[2]))
for i in range(1, num_labels):  # this loop starts from label 1 to ignore background 0
    segD[i, :, :, :] = seg == labels[i]
```
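The loop compares the whole volume against one label at a time. A minimal sketch of doing all comparisons in one vectorized step with NumPy broadcasting, assuming `labels` is the 1-D output of `np.unique`. `one_hot_broadcast` is a hypothetical helper name, and the small `seg` array below is made-up data standing in for the NIfTI volume:

```python
import numpy as np

def one_hot_broadcast(seg: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Return a (num_labels, H, W, D) one-hot array via broadcasting.

    Channel 0 (background) is left all-zero to mirror the loop above.
    """
    # labels[:, None, None, None] has shape (C, 1, 1, 1) and seg[None, ...] has
    # shape (1, H, W, D); comparing them broadcasts to (C, H, W, D) in one step.
    onehot = (seg[None, ...] == labels[:, None, None, None]).astype(np.float64)
    onehot[0] = 0.0  # ignore background, as the original loop does
    return onehot

# Hypothetical 2x2x2 mask standing in for the real segmentation volume.
seg = np.array([[[0, 1], [2, 56]], [[8, 0], [1, 1]]])
labels = np.unique(seg)  # labels 0, 1, 2, 8, 56
segD = one_hot_broadcast(seg, labels)
print(segD.shape)  # (5, 2, 2, 2)
```

The temporary boolean array has the same C×H×W×D size as the output, so peak memory is comparable to the loop version; the gain is removing the Python-level loop.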
Is there a better, more efficient way to do this?
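One commonly used alternative, sketched here under assumptions: `np.unique(..., return_inverse=True)` maps each voxel to a dense class index (so sparse values such as 56 become 0..C-1), and indexing an identity matrix with those indices produces the one-hot vectors. `one_hot_eye` is a hypothetical helper name and the `seg` array is made-up example data. Unlike the loop, channel 0 here marks background voxels; set `onehot[0] = 0` afterwards if the original behaviour is wanted.

```python
import numpy as np

def one_hot_eye(seg: np.ndarray):
    """Return (labels, onehot) where onehot has shape (C, H, W, D)."""
    # inverse[v] is the position of seg's value v within `labels`.
    labels, inverse = np.unique(seg, return_inverse=True)
    # reshape is a no-op on NumPy versions that already return seg.shape,
    # and restores the volume shape on versions that return a flat array.
    inverse = inverse.reshape(seg.shape)
    # Row k of np.eye(C) is the one-hot vector for class k, so fancy indexing
    # yields an (H, W, D, C) array ...
    onehot = np.eye(len(labels), dtype=np.uint8)[inverse]
    # ... and moving the class axis to the front gives (C, H, W, D).
    return labels, np.moveaxis(onehot, -1, 0)

# Hypothetical 2x2x2 mask standing in for the real segmentation volume.
seg = np.array([[[0, 1], [2, 56]], [[8, 0], [1, 1]]])
labels, onehot = one_hot_eye(seg)
print(onehot.shape)  # (5, 2, 2, 2)
```

Using `uint8` instead of the default `float64` cuts memory by 8x; cast with `.astype(np.float32)` only if a float array is needed downstream.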
I tried the `torch.scatter_` method in PyTorch but couldn't get it to work. Moreover, I need the one-hot encoding as a NumPy array, whereas `.scatter_` operates on PyTorch tensors.
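NumPy has a scatter-style counterpart, `np.put_along_axis`, which writes a value at a per-position index along one axis. A minimal NumPy-only sketch, assuming labels are first mapped to dense indices with `np.unique(..., return_inverse=True)`; `one_hot_scatter` is a hypothetical helper name and `seg` is made-up example data. For comparison, the PyTorch version would be roughly `torch.zeros(C, H, W, D).scatter_(0, index, 1)` with an int64 `index` of shape `(1, H, W, D)`, followed by `.numpy()` to get a NumPy array back.

```python
import numpy as np

def one_hot_scatter(seg: np.ndarray):
    """Return (labels, onehot) with onehot of shape (C, H, W, D), built by scattering."""
    labels, inverse = np.unique(seg, return_inverse=True)
    # Index array of shape (1, H, W, D), like the `index` argument of scatter_.
    index = inverse.reshape((1,) + seg.shape)
    onehot = np.zeros((len(labels),) + seg.shape, dtype=np.uint8)
    # For every voxel (h, w, d): onehot[index[0, h, w, d], h, w, d] = 1,
    # i.e. a scatter along axis 0.
    np.put_along_axis(onehot, index, 1, axis=0)
    return labels, onehot

# Hypothetical 2x2x2 mask standing in for the real segmentation volume.
seg = np.array([[[0, 1], [2, 56]], [[8, 0], [1, 1]]])
labels, onehot = one_hot_scatter(seg)
print(onehot.shape)  # (5, 2, 2, 2)
```

As with the `np.eye` approach, channel 0 here encodes background; zero it out if it should be ignored as in the original loop.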