[Image Segmentation] Preprocessing of multilabels in 2D training data before training model

Hi,

I have 2D training data (image, label) of shape 496x512. Each image can have multiple labels (4 classes in total: 0, 1, 2, or 3).

Here is what I have so far:
From the Dataset:
BEFORE TRANSFORMATION:
→ label shape from OCTDataset: (496, 512)
→ img shape from OCTDataset: (496, 512)
→ image dtype from OCTDataset: float32, label dtype from OCTDataset: float32

AFTER TRANSFORMATION:
→ shape just after transform of the label: torch.Size([1, 512, 512])
→ shape just after transform of the img: torch.Size([1, 512, 512])
→ image dtype after transform: torch.float32, label dtype after transform: torch.float32

FROM DATALOADER:
image shape: torch.Size([2, 1, 512, 512]), label shape: torch.Size([2, 1, 512, 512])
image dtype: torch.float32, label dtype: torch.float32

My Question:
In total there can be 4 classes (including background), but how should I encode these classes? With the current code, the label from the DataLoader has shape torch.Size([2, 1, 512, 512]), but I guess it should be torch.Size([2, classes=4, 512, 512]). How should I get that?

Thanks!

I assume you are working on a multi-label segmentation use case, where each pixel can belong to 0, 1, or more classes.
In this case the target should be multi-hot encoded, where a 0 or 1 indicates whether the class is inactive or active in the current pixel, respectively.
Based on your current label shape, it seems you are using a single channel only, so are you using class indices?
If so, then your use case sounds like multi-class segmentation, where each pixel belongs to exactly one class?
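To make the distinction concrete, here is a minimal NumPy sketch on a made-up 3x3 label map (the helper `one_hot` and the toy values are illustrative assumptions, not code from this thread). A multi-class target stores one class index per pixel, and its one-hot version always has exactly one active channel per pixel; a multi-hot (multi-label) target may have zero or several. In PyTorch, the class-index form (`[N, H, W]`, `torch.long`) is what `nn.CrossEntropyLoss` is usually given, while a multi-hot `[N, C, H, W]` float target is usually paired with `nn.BCEWithLogitsLoss`.

```python
import numpy as np

NUM_CLASSES = 4  # background + 3 fluid classes, as in the question

# Multi-class target: one class index per pixel, shape [H, W]
class_map = np.array([[0, 0, 1],
                      [0, 2, 1],
                      [3, 3, 0]])

def one_hot(class_map, num_classes):
    """Convert a [H, W] index map into a [C, H, W] one-hot array."""
    classes = np.arange(num_classes)[:, None, None]  # shape [C, 1, 1]
    return (class_map[None, :, :] == classes).astype(np.float32)

target_mc = one_hot(class_map, NUM_CLASSES)
print(target_mc.shape)                   # (4, 3, 3)
print(np.unique(target_mc.sum(axis=0)))  # [1.] -> exactly one class per pixel

# Multi-label target: each channel is an independent binary mask, shape [C, H, W]
target_ml = np.zeros((NUM_CLASSES, 3, 3), dtype=np.float32)
target_ml[1, 0, 2] = 1
target_ml[2, 0, 2] = 1  # the same pixel is active in two classes
print(np.unique(target_ml.sum(axis=0)))  # [0. 2.] -> zero or several classes per pixel
```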

My Task:
Actually, the original training data is 3D, with shape 49x496x512 for both images and labels. I have to create 2D images and labels from it and then train a UNet to segment each class (3 eye fluids).

Here is the code for creating 2D slices from the 3D data:

import os
import SimpleITK as sitk

# give list of all oct images and references from the provided ROOT folder
def preprocess_oct_images_in_numpy(dir:str):
  references = list() #only for training data
  oct_images =list()
  for subdir, dirs, files in os.walk(dir):
    for file in files:
      filepath = os.path.join(subdir, file)
      if filepath.endswith("reference.mhd"):
        references.append(filepath)     
      elif filepath.endswith("oct.mhd"):
        oct_images.append(filepath)     

  references = sorted(references)
  oct_images = sorted(oct_images)
  
  updated_image_list =[]
  updated_label_list =[] #only for training data
  
  # creating 2D image (496x512) from each slice of 3D Image (49x496x512)
  # For image
  for data in range(len(oct_images)):
    image_numpy_3D = sitk.GetArrayFromImage(sitk.ReadImage(oct_images[data], sitk.sitkFloat32)) # order z,y,x
    for idx in range(image_numpy_3D.shape[0]):
      updated_image_list.append(image_numpy_3D[idx,:,:])

  # for label
  for index in range(len(references)):
    label_numpy_3D = sitk.GetArrayFromImage(sitk.ReadImage(references[index], sitk.sitkFloat32)) # order z,y,x
    for idx in range(label_numpy_3D.shape[0]):
      updated_label_list.append(label_numpy_3D[idx,:,:])

  return updated_image_list, updated_label_list
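The two slicing loops can also be expressed directly with NumPy, since iterating over (or calling `list()` on) a 3D array yields its 2D slices along axis 0. The sketch below uses small random arrays as stand-ins for the `.mhd` volumes, and the helper `volumes_to_slices` is hypothetical. It also checks that images and references pair up, because the function above sorts the two file lists independently and pairs them by position, which assumes every `oct.mhd` has a matching `reference.mhd` with the same number of slices.

```python
import numpy as np

def volumes_to_slices(volumes):
    """Split a list of [Z, H, W] volumes into a flat list of [H, W] slices."""
    slices = []
    for vol in volumes:
        slices.extend(list(vol))  # iterating a 3D array yields its slices along axis 0
    return slices

# Stand-ins for the volumes read with SimpleITK (real data: 49 x 496 x 512)
rng = np.random.default_rng(0)
image_volumes = [rng.random((3, 4, 5), dtype=np.float32) for _ in range(2)]
label_volumes = [rng.integers(0, 4, size=(3, 4, 5)).astype(np.float32) for _ in range(2)]

# Every image volume needs a label volume of the same shape
assert len(image_volumes) == len(label_volumes)
for img_vol, lbl_vol in zip(image_volumes, label_volumes):
    assert img_vol.shape == lbl_vol.shape

images_2d = volumes_to_slices(image_volumes)
labels_2d = volumes_to_slices(label_volumes)
print(len(images_2d), images_2d[0].shape)  # 6 (4, 5)
```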

This is what one of my training samples (image and label) looks like before and after the transform (in 2D):

Thanks. So I have now created a multi-hot (one-hot) encoding of the labels inside the dataset class (I have a second dataset class for the transformation). With this, the label shape is 4x496x512, and after the transform (with padding) it becomes 4x512x512.

import numpy as np
from torch.utils.data import Dataset

class OCTDataset(Dataset):
  "OCT Scan data set"
  def __init__(self, images, references):
    super(OCTDataset, self).__init__()
    self.images = images
    self.references = references

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img = self.images[idx]
    label = self.references[idx] 

    # manual multi one hot encoding
    new_label = np.zeros((4, 496, 512),dtype=np.float32) # 4 here is 4 classes including background and 3 eye fluids

    for iRow in range(label.shape[0]):
      for iCol in range(label.shape[1]):
        if label[iRow,iCol] == 1:
          new_label[iRow,iCol,1] = 1
        if label[iRow,iCol] == 2:
          new_label[iRow,iCol,2] = 2
        if label[iRow,iCol] == 3:
          new_label[iRow,iCol,3] = 3

    return img, new_label
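About the padding step (4x496x512 to 4x512x512): if the 4-channel target is padded with zeros, the new rows are zero in every channel, so they belong to no class, not even background. A minimal NumPy sketch, assuming the transform pads 8 rows above and below (the actual transform may split the padding differently) and using a made-up fluid region:

```python
import numpy as np

label = np.zeros((496, 512), dtype=np.int64)  # class-index map, all background here
label[100:110, 200:220] = 2                   # a made-up fluid region

# One-hot with the class dimension first: [4, 496, 512]
one_hot = (label[None] == np.arange(4)[:, None, None]).astype(np.float32)

# Zero padding of the height from 496 to 512 (assumed: 8 rows on each side)
padded = np.pad(one_hot, ((0, 0), (8, 8), (0, 0)), mode="constant", constant_values=0)
print(padded.shape)     # (4, 512, 512)

# Padded rows have no active class at all...
print(padded[:, 0, 0])  # [0. 0. 0. 0.]

# ...so one option is to mark them as background explicitly
padded[0, :8, :] = 1
padded[0, -8:, :] = 1
assert (padded.sum(axis=0) == 1).all()
```

Alternatively, padding the single-channel class-index map with 0 (background) before one-hot encoding avoids the issue altogether.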

How can I check whether this is multi-class segmentation or multi-label segmentation?
Maybe in 2D it looks like multi-class, but when the 3D volume has to be rebuilt from the 2D slices, it might become multi-label segmentation (because the identified fluid regions could overlap). Is that correct?

Many thanks!

As previously described, the main difference is whether each “element” (pixel, voxel, sample) belongs to only one class (multi-class segmentation/classification) or can belong to zero, one, or multiple classes (multi-label segmentation/classification).
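Applied to this data, one practical check is how the reference is stored: a label volume holding a single integer per voxel can assign only one class to each voxel, which is the multi-class case, and stacking its 2D slices back into a volume cannot create overlaps. A small sketch with a random stand-in volume (the helper `is_multi_class` is hypothetical):

```python
import numpy as np

def is_multi_class(target):
    """True if every element of a [C, ...] multi-hot target has exactly one active class."""
    return bool((target.sum(axis=0) == 1).all())

# Stand-in for one reference volume (real data: 49 x 496 x 512, values 0..3)
rng = np.random.default_rng(0)
label_3d = rng.integers(0, 4, size=(5, 6, 7))

# Slice into 2D maps and stack them back along the slice axis
slices = list(label_3d)
rebuilt = np.stack(slices, axis=0)
assert np.array_equal(rebuilt, label_3d)  # rebuilding adds no overlaps

# One-hot over the whole volume: [4, Z, H, W]
one_hot_3d = (rebuilt[None] == np.arange(4)[:, None, None, None]).astype(np.float32)
print(is_multi_class(one_hot_3d))  # True
```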

I don’t quite understand your label creation, as you seem to index the spatial dimensions with the class index?

    # manual multi one hot encoding
    new_label = np.zeros((4, 496, 512),dtype=np.float32) # 4 here is 4 classes including background and 3 eye fluids

    for iRow in range(label.shape[0]):
      for iCol in range(label.shape[1]):
        if label[iRow,iCol] == 1:
          new_label[iRow,iCol,1] = 1
        if label[iRow,iCol] == 2:
          new_label[iRow,iCol,2] = 2
        if label[iRow,iCol] == 3:
          new_label[iRow,iCol,3] = 3

It seems dim1 and dim2 are the spatial dimensions, while you are indexing dim0 as a spatial dimension (and dim2 with the class index).
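In other words, the class index belongs in the first dimension of a `[C, H, W]` array. If it is easier to build the encoding channel-last (`[H, W, C]`, which is what the indexing `new_label[iRow, iCol, c]` implies), the class axis can be moved to the front afterwards. A sketch on a made-up 2x2 label map:

```python
import numpy as np

label = np.array([[0, 1],
                  [2, 3]])  # toy [H, W] class-index map
# Note: a float32 label (as read with sitkFloat32) needs .astype(np.int64) before indexing

# Channel-last one-hot: [H, W, C]; np.eye(4)[label] picks one identity row per pixel
one_hot_hwc = np.eye(4, dtype=np.float32)[label]
print(one_hot_hwc.shape)  # (2, 2, 4)

# Move the class axis to the front: [C, H, W], matching a [N, C, H, W] model output
one_hot_chw = np.moveaxis(one_hot_hwc, -1, 0)
print(one_hot_chw.shape)  # (4, 2, 2)
print(one_hot_chw[2])     # mask for class 2
# [[0. 0.]
#  [1. 0.]]
```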

You mean like this?

class OCTDataset(Dataset):
  "OCT Scan data set"
  def __init__(self, images, references):
    super(OCTDataset, self).__init__()
    self.images = images
    self.references = references

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img = self.images[idx]
    label = self.references[idx] 

    # manual one-hot encoding
    new_label = np.zeros((4, label.shape[0], label.shape[1]),dtype=np.float32) # 4 here is 4 classes including background and 3 eye fluids

    for iRow in range(label.shape[0]):
      for iCol in range(label.shape[1]):
        if label[iRow,iCol] == 1:
          new_label[1,iRow,iCol] = 1
        if label[iRow,iCol] == 2:
          new_label[2, iRow,iCol] = 1
        if label[iRow,iCol] == 3:
          new_label[3, iRow,iCol] = 1
    return img, new_label
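With the class dimension first, the indexing now matches the `[4, H, W]` layout. Two small observations, offered as a sketch rather than a required change: channel 0 (background) is never set in this loop, so background pixels are all-zero across the four channels instead of one-hot, and the double loop runs in Python for every pixel of every sample. A vectorized version that also fills the background channel could look like this (the helper names `encode_label` and `encode_label_loop` are hypothetical; the latter just restates the loop above for comparison):

```python
import numpy as np

NUM_CLASSES = 4  # background + 3 eye fluids

def encode_label(label, num_classes=NUM_CLASSES):
    """[H, W] class-index map (possibly float32) -> [C, H, W] float32 one-hot."""
    label = label.astype(np.int64)  # labels were read as float32
    classes = np.arange(num_classes)[:, None, None]
    return (label[None] == classes).astype(np.float32)

def encode_label_loop(label):
    """The loop from the post above (background channel not set)."""
    new_label = np.zeros((4, label.shape[0], label.shape[1]), dtype=np.float32)
    for iRow in range(label.shape[0]):
        for iCol in range(label.shape[1]):
            c = int(label[iRow, iCol])
            if c in (1, 2, 3):
                new_label[c, iRow, iCol] = 1
    return new_label

rng = np.random.default_rng(0)
label = rng.integers(0, 4, size=(64, 64)).astype(np.float32)  # small stand-in for 496x512

fast = encode_label(label)
slow = encode_label_loop(label)
print(fast.shape)                          # (4, 64, 64)
print(np.array_equal(fast[1:], slow[1:]))  # True: fluid channels agree
print(np.array_equal(fast[0], label == 0)) # True: background channel is filled in
```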