Context: I am doing image segmentation using PyTorch. Before feeding the training data to the network, I need to normalise it.
My images are 256x256x3, and my masks are also 256x256x3.
I have a TrainDataset class, and each sample it returns is a dict: I access the image with sample['image'] and the mask with sample['semantic'].
The question is: how can I apply normalisation to a dataset whose samples are dicts?
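What I imagine is a transform that works on the whole sample dict, normalising only the image and leaving the mask untouched. Something like the sketch below is what I am after (the NormalizeSample name and the mean/std values are just placeholders I made up, and I assume the image has already been turned into a CxHxW tensor by my ToTensor):

import torch

class NormalizeSample(object):
    """My rough attempt: scale the image to [0, 1] and apply channel-wise
    mean/std, leaving the segmentation mask untouched. The mean/std values
    are placeholders, not measured from my data."""
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = torch.tensor(mean).view(3, 1, 1)
        self.std = torch.tensor(std).view(3, 1, 1)

    def __call__(self, sample):
        image, semantic = sample['image'], sample['semantic']
        image = image.float() / 255.0            # uint8 -> float32 in [0, 1]
        image = (image - self.mean) / self.std   # channel-wise normalisation
        return {'image': image, 'semantic': semantic}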
I have put my TrainDataset class below for reference:
import os
from skimage import io
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class TrainDataset(Dataset):
    """Training dataset with mask image on grayscale/RGB."""

    def __init__(self, train_dir, semantic_dir, transform=None):
        """
        Args:
            train_dir (string): Directory with the training images.
            semantic_dir (string): Directory with the semantic segmentation masks.
            transform (callable, optional): Transform to be applied on a sample.
        """
        self.train_dir = train_dir
        self.semantic_dir = semantic_dir
        self.transform = transform

    def __len__(self):
        return len(os.listdir(self.train_dir))

    def __getitem__(self, idx):
        # Sort both directory listings so each image is paired with its mask deterministically
        img_name = os.path.join(self.train_dir, sorted(os.listdir(self.train_dir))[idx])
        semantic_name = os.path.join(self.semantic_dir, sorted(os.listdir(self.semantic_dir))[idx])
        image = io.imread(img_name)
        semantic = io.imread(semantic_name)
        sample = {'image': image, 'semantic': semantic}
        if self.transform:
            sample = self.transform(sample)
        return sample
composed = transforms.Compose([RandomCrop(256), ToTensor()])  # RandomCrop and ToTensor are my own dict-aware transforms
transformed_dataset = TrainDataset(train_dir=train, semantic_dir=train_mask, transform=composed)
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=False, num_workers=4)
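If a dict-aware class like my NormalizeSample sketch above is the right approach, I would expect to slot it into the Compose after ToTensor, roughly like this (again just a guess):

composed = transforms.Compose([RandomCrop(256), ToTensor(), NormalizeSample()])  # normalise once the sample is a tensor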
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(), sample_batched['semantic'].size())
    print(sample_batched['image'].dtype)
    if i_batch == 3:
        break
Here is the output (batch sizes and dtype):
0 torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.uint8
1 torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.uint8
2 torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.uint8
3 torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.uint8
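Since the batches come out as torch.uint8, I assume the normalisation also has to convert the images to float. If it works, I would expect the dtype print above to change, e.g.:

print(sample_batched['image'].dtype)   # expected: torch.float32
print(sample_batched['image'].mean())  # expected: roughly 0 if mean/std normalisation was applied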
Thank you in advance.