Converting 3D label to one-hot encoding

Hi, I have a 3D label volume and would like to convert it to a one-hot encoding. For example, a label of shape (C, D, H, W) with C=1 should become (2, D, H, W) after one-hot encoding. Is there a quick way to do this? This is what I have so far:
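If the label values are integers from 0 to K-1 (for the C=1 to 2 case, that means a binary 0/1 label), one quick option is to index an identity matrix with the label volume. This is a minimal sketch under that assumption. `one_hot_eye` and `num_classes` are hypothetical names, and values outside [0, num_classes) would raise an IndexError.

```python
import numpy as np

def one_hot_eye(img, num_classes):
    """One-hot encode a (1, D, H, W) label volume into (num_classes, D, H, W).

    Hypothetical helper: assumes every value in img is an integer
    in [0, num_classes).
    """
    # Row x of the identity matrix is the one-hot vector for class x, so
    # indexing with the (D, H, W) volume gives shape (D, H, W, num_classes).
    encoded = np.eye(num_classes, dtype=np.float32)[img[0].astype(np.int64)]
    # Move the class axis to the front: (num_classes, D, H, W).
    return np.moveaxis(encoded, -1, 0)

# Usage with a random binary label of shape (1, 4, 5, 6)
rng = np.random.default_rng(0)
label = rng.integers(0, 2, size=(1, 4, 5, 6))
one_hot = one_hot_eye(label, 2)
print(one_hot.shape)  # (2, 4, 5, 6)
```

Because `num_classes` is fixed, every volume gets the same number of channels, even one that happens to contain only background.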

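```python
import numpy as np

# img has shape (1, D, H, W); labels holds its distinct values
labels = np.unique(img)
one_hot = np.zeros((len(labels),) + img.shape[1:])
for i, label in enumerate(labels):
    # Channel i is 1 wherever the voxel equals the i-th label
    one_hot[i] = img[0] == label
```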
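The loop can also be replaced by a single broadcast comparison. This sketch keeps the loop's behaviour (one channel per value returned by `np.unique`). `one_hot_unique` is a hypothetical name, and the float32 output dtype is an assumption.

```python
import numpy as np

def one_hot_unique(img):
    """One-hot encode a (1, D, H, W) label volume over its distinct values.

    Returns (one_hot, labels): one_hot has shape (K, D, H, W), where K is
    the number of distinct values in img, and channel i marks labels[i].
    """
    labels = np.unique(img)  # sorted distinct label values, shape (K,)
    # Reshape labels to (K, 1, 1, 1) so it broadcasts against img[0],
    # shape (D, H, W); the comparison yields a boolean (K, D, H, W) array.
    one_hot = labels.reshape(-1, 1, 1, 1) == img[0]
    return one_hot.astype(np.float32), labels

# Usage with a tiny label volume of shape (1, 1, 2, 2)
img = np.array([0, 3, 3, 7]).reshape(1, 1, 2, 2)
one_hot, labels = one_hot_unique(img)
print(labels)         # [0 3 7]
print(one_hot.shape)  # (3, 1, 2, 2)
```

Note that, like the loop, this derives the channels from the values actually present, so a volume containing a single label produces one channel rather than two.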