Method to create one-hot encoding in torch.utils.data.Dataset

Hello All,

I am trying to create a dataset for multi-class segmentation. Each image has shape (48, 64, 64), and the segmentation masks have 9 different labels.

```python
imgnp = imgnp[None, ...]
segD = np.zeros((num_labels, 48, 64, 64))
for i in range(0, num_labels):  # labels starts from label 1 (labels[0] == 1)
    seg_one = segnp == labels[i]
    segD[i, :, :, :] = seg_one[0:segnp.shape[0], 0:segnp.shape[1], 0:segnp.shape[2]]
imgD = imgnp.astype('float32')
segD = segD.astype('float32')
return imgD, segD
```

The outputs are an image of shape (1, 48, 64, 64) and a one-hot encoded binary segmentation mask of shape (9, 48, 64, 64), with one channel per tissue label.

My question: is this an efficient way to create the dataset, or are there mistakes or better ways to do it?

N.B.: this chunk of code is from `__getitem__`.
Thanks in advance.
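For comparison, here is a sketch of a vectorized `__getitem__`. It assumes a hypothetical `SegDataset` class that holds in-memory NumPy arrays, with tissue labels 1..9 and background 0. A single broadcast comparison builds all 9 channels at once instead of looping over labels. The explicit slicing in the loop is also a no-op when the mask is already (48, 64, 64).

```python
import numpy as np
from torch.utils.data import Dataset


class SegDataset(Dataset):
    """Hypothetical in-memory dataset: images and masks are lists of (48, 64, 64) arrays."""

    def __init__(self, images, masks, labels):
        self.images = images
        self.masks = masks
        # Shape (C, 1, 1, 1) so it broadcasts against a (1, D, H, W) mask
        self.labels = np.asarray(labels).reshape(-1, 1, 1, 1)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        imgD = self.images[idx][None, ...].astype(np.float32)  # (1, D, H, W)
        # One comparison replaces the per-label loop: (C, D, H, W), 1 where mask == labels[c]
        segD = (self.masks[idx][None, ...] == self.labels).astype(np.float32)
        return imgD, segD


# Hypothetical example data: tissue labels 1..9, background 0
rng = np.random.default_rng(0)
images = [rng.random((48, 64, 64)) for _ in range(2)]
masks = [rng.integers(0, 10, size=(48, 64, 64)) for _ in range(2)]
ds = SegDataset(images, masks, labels=range(1, 10))
img, seg = ds[0]
print(img.shape, seg.shape)  # (1, 48, 64, 64) (9, 48, 64, 64)
```

Voxels whose value is not in `labels` (for example, background 0) get all-zero channels, the same as in the loop. Returning NumPy arrays is fine because `DataLoader`'s default collate function converts them to tensors.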

Use `scatter_`:

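```python
import torch
from torch import Tensor


def create_one_hot(x: Tensor, n_class: int) -> Tensor:
    """
    :param x: [B, D1, D2, D3] integer labels, 0 <= x.min() and x.max() < n_class
    :return: [B, n_class, D1, D2, D3] float one-hot tensor
    """
    B, D1, D2, D3 = x.shape
    out = torch.zeros((B, n_class, D1, D2, D3), dtype=torch.float, device=x.device)
    # scatter_ needs an int64 index with the same number of dims as out
    x = x[:, None, :, :, :].long()
    # Sets out[b, x[b, 0, i, j, k], i, j, k] = 1
    out.scatter_(dim=1, index=x, src=torch.ones_like(x, dtype=out.dtype))
    return out
```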
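Inside `__getitem__` you usually have a single sample with no batch dimension. Here is a sketch for that case using the built-in `torch.nn.functional.one_hot`, alongside an unbatched `scatter_` version for comparison. The function names are hypothetical. The sketch also assumes the mask values are already class indices in `[0, n_class)`.

```python
import torch
import torch.nn.functional as F


def one_hot_sample(mask: torch.Tensor, n_class: int) -> torch.Tensor:
    """One-hot encode one [D1, D2, D3] mask to [n_class, D1, D2, D3].

    Assumes mask values are class indices in [0, n_class).
    """
    # F.one_hot appends the class dim last: [D1, D2, D3, n_class] (int64)
    oh = F.one_hot(mask.long(), num_classes=n_class)
    # Move classes to the front and cast to float
    return oh.permute(3, 0, 1, 2).float()


def one_hot_scatter(mask: torch.Tensor, n_class: int) -> torch.Tensor:
    """Same result with scatter_, for a mask without a batch dimension."""
    out = torch.zeros((n_class, *mask.shape), dtype=torch.float, device=mask.device)
    # Sets out[mask[i, j, k], i, j, k] = 1
    out.scatter_(0, mask.long()[None], 1.0)
    return out


mask = torch.randint(0, 9, (48, 64, 64))  # hypothetical mask with classes 0..8
a = one_hot_sample(mask, 9)
b = one_hot_scatter(mask, 9)
print(a.shape, torch.equal(a, b))  # torch.Size([9, 48, 64, 64]) True
```

`F.one_hot` is the shorter option. The `scatter_` version avoids the intermediate channels-last int64 tensor. If your masks store tissue labels 1..9 with background 0, you can shift them to 0..8 (after deciding how to treat background). Alternatively, use `n_class=10` and drop channel 0.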

For creating a one-hot tensor, I have implemented five functions…

Hope it works for you.