Hi I am having problem while converting rgb mask of shape [224,224,3] to mask of shape [224,224,3]. I have attached the code below.
I am getting masks of shape [224,224,classes] but lose information of classes in channels only one channel has some mask will others don’t
Code:
class CamVid_Dataset():
def __init__(self, img_pth, mask_pth, transform):
self.img_pth = img_pth
self.mask_pth = mask_pth
self.transform = transform
all_imgs = os.listdir(self.img_pth)
all_masks = os.listdir(self.mask_pth)
self.total_imgs = natsort.natsorted(all_imgs)
self.total_masks = natsort.natsorted(all_masks)
def __len__(self):
return len(self.total_imgs)
def __getitem__(self, idx):
img_loc = os.path.join(self.img_pth, self.total_imgs[idx])
image = Image.open(img_loc).convert("RGB")
tensor_image = self.transform(image)
mask_loc = os.path.join(self.mask_pth, self.total_masks[idx])
mask = Image.open(mask_loc).convert("RGB")
tensor_mask = self.transform(mask)
tensor_mask = rgb_to_mask(np.array(tensor_mask).transpose(1,2,0), id2code)
return tensor_image, tensor_mask
#Define transforms for the training data and validation data
train_transforms = transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor()])
Pass transform here-in
train_data = CamVid_Dataset(img_pth = path + 'train/', mask_pth = path + 'train_labels/', transform = train_transforms)
Data loaders
trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False)
inputs, mask = next(iter(trainloader))
RGB2MASK:
def rgb_to_mask(img, color_map):
'''
Converts a RGB image mask of shape [batch_size, h, w, 3] to Binary Mask of shape [batch_size, classes, h, w]
Parameters:
img: A RGB img mask
color_map: Dictionary representing color mappings
returns:
out: A Binary Mask of shape [batch_size, classes, h, w]
'''
num_classes = len(color_map)
shape = img.shape[:2]+(num_classes,)
out = np.zeros(shape, dtype=np.int8)
for i, cls in enumerate(color_map):
out[:,:,i] = np.all(img.reshape( (-1,3) ) == color_map[i], axis=1).reshape(shape[:2])
return out