Thanks for the link!
I’ve downloaded the BMMC dataset with the masks, and it seems that three different color codes are used in the masks, as you already explained.
This code should load the masks and convert these values to class indices:
class CustomDataset(Dataset):
    """Segmentation dataset pairing image files with mask files.

    Each raw mask pixel value is converted to a class index via ``mapping``
    so the target can be used with a classification criterion such as
    ``nn.CrossEntropyLoss``.

    Args:
        image_paths: Sequence of paths to the input images.
        target_paths: Sequence of paths to the masks; ``target_paths[i]`` is
            loaded together with ``image_paths[i]``.
        mapping: Optional dict from raw mask pixel value to class index.
            Defaults to ``DEFAULT_MAPPING`` (``{85: 0, 170: 1, 255: 2}``).
    """

    # Default raw-pixel-value -> class-index mapping used for the BMMC masks.
    DEFAULT_MAPPING = {85: 0, 170: 1, 255: 2}

    def __init__(self, image_paths, target_paths, mapping=None):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transforms = transforms.ToTensor()
        # Copy so later changes to self.mapping never touch the caller's dict
        # or the shared class-level default.
        self.mapping = dict(self.DEFAULT_MAPPING if mapping is None else mapping)

    def mask_to_class(self, mask):
        """Map raw mask pixel values to class indices.

        Every comparison is made against the *original* mask, so a mapping
        whose target value is also a source key (e.g. ``{85: 170, 170: 1}``)
        is never applied twice to the same pixel. Pixel values that are not
        keys of ``self.mapping`` are kept unchanged.

        Args:
            mask: Integer tensor of raw mask pixel values. It is not modified.

        Returns:
            A new ``torch.long`` tensor with the same shape as ``mask``. This
            is the target dtype ``nn.CrossEntropyLoss`` expects for class
            indices.
        """
        # clone() guarantees a fresh tensor even if mask is already long,
        # so the caller's tensor is never written to.
        classes = mask.clone().to(torch.long)
        for value, class_idx in self.mapping.items():
            classes[mask == value] = class_idx
        return classes

    def __getitem__(self, index):
        """Return ``(image_tensor, class_index_mask)`` for sample ``index``."""
        # Context managers close the file handles deterministically instead of
        # relying on PIL / garbage collection to release them.
        with Image.open(self.image_paths[index]) as image:
            t_image = self.transforms(image)
        with Image.open(self.target_paths[index]) as mask_img:
            mask = torch.from_numpy(np.array(mask_img))
        return t_image, self.mask_to_class(mask)

    def __len__(self):
        """Return the number of samples (one per entry in ``image_paths``)."""
        return len(self.image_paths)
if __name__ == "__main__":
    # The guard is needed because num_workers > 0 starts worker processes; on
    # platforms that use the "spawn" start method (Windows, macOS) each worker
    # re-imports this module, and unguarded top-level code would run again.
    train_dataset = CustomDataset(train_image_paths, train_mask_paths)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=4, shuffle=True, num_workers=2
    )
    # Print the distinct values in each target batch to check the remapping.
    for data, target in train_loader:
        print(torch.unique(target))
You can of course change the mapping, e.g. if the pixel value 85 should be mapped to another class index.
Just make sure your class indices start at 0 and end at num_classes-1 if you would like to use a classification criterion like nn.CrossEntropyLoss, which also expects the target tensor to have dtype torch.long.