How to normalize a ConcatDataset?

I’ve created a small code sample here that calculates the mean and std of your dataset on the fly, in case not all images fit into memory.
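The linked sample is not reproduced here. As a rough illustration of the same idea, the sketch below accumulates per-channel sums and sums of squares batch by batch, so only one batch is in memory at a time. It assumes a `DataLoader` yielding image tensors of shape `[B, C, H, W]` (or tuples whose first element is such a tensor). It computes the population std over all pixels, which is not the same as averaging per-image stds. `compute_mean_std` is a hypothetical helper name.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def compute_mean_std(loader):
    """Per-channel mean and std over all pixels, accumulated batch by batch.

    Assumes each batch is an image tensor [B, C, H, W] or a tuple/list
    whose first element is one (e.g. (images, labels)).
    """
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    for batch in loader:
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        images = images.double()  # float64 reduces accumulation error
        # Sum over batch, height and width; keep the channel dimension
        s = images.sum(dim=(0, 2, 3))
        sq = (images ** 2).sum(dim=(0, 2, 3))
        if channel_sum is None:
            channel_sum, channel_sq_sum = s, sq
        else:
            channel_sum += s
            channel_sq_sum += sq
        n_pixels += images.size(0) * images.size(2) * images.size(3)
    mean = channel_sum / n_pixels
    # Population variance E[x^2] - E[x]^2; clamp guards tiny negative values
    var = (channel_sq_sum / n_pixels - mean ** 2).clamp(min=0)
    return mean.float(), var.sqrt().float()


# Usage with random stand-in data (hypothetical; replace with your dataset)
torch.manual_seed(0)
images = torch.rand(100, 3, 8, 8)
loader = DataLoader(TensorDataset(images), batch_size=16)
mean, std = compute_mean_std(loader)
```

The resulting `mean` and `std` tensors hold one value per channel, which is the form `transforms.Normalize` expects.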

Once you’ve calculated the mean and std, you can create a Dataset and use `transforms.Normalize` (which expects a `[C, H, W]` tensor and one mean/std value per channel) to normalize the images:

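```python
from torch.utils.data import Dataset
from torchvision import transforms

# mean and std: per-channel statistics computed beforehand
transform = transforms.Normalize(mean=mean, std=std)


class MyDataset(Dataset):
    """Wraps any indexable dataset and applies an optional transform to each sample."""
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        return len(self.data)


# data is assumed to return image tensors; if it returns (image, target)
# tuples, apply the transform to the image only
dataset = MyDataset(data, transform=transform)
```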
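Since a `ConcatDataset` is itself an indexable Dataset, one option is to wrap the whole concatenation once, as sketched below. The two parts are hypothetical random stand-ins, and to keep the sketch free of a torchvision dependency the normalization is written out by hand as the per-channel `(x - mean) / std` that `transforms.Normalize` computes; in practice you would pass `transforms.Normalize(mean, std)` instead. `make_normalize` is a hypothetical helper.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class MyDataset(Dataset):
    """Wraps any indexable dataset and applies an optional transform to each sample."""
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        return len(self.data)


def make_normalize(mean, std):
    """Hypothetical stand-in for transforms.Normalize on a [C, H, W] tensor."""
    mean = torch.as_tensor(mean).view(-1, 1, 1)
    std = torch.as_tensor(std).view(-1, 1, 1)
    return lambda x: (x - mean) / std


# Two hypothetical image datasets with different value ranges
torch.manual_seed(0)
part_a = MyDataset(torch.rand(10, 3, 4, 4))
part_b = MyDataset(torch.rand(6, 3, 4, 4) * 2)
concat = ConcatDataset([part_a, part_b])

# Statistics over the whole concatenation (small enough to stack here)
all_images = torch.stack([concat[i] for i in range(len(concat))])
mean = all_images.mean(dim=(0, 2, 3))
std = ((all_images - mean.view(1, 3, 1, 1)) ** 2).mean(dim=(0, 2, 3)).sqrt()

# Wrap the ConcatDataset once, so every sample from either part is normalized
normalized = MyDataset(concat, transform=make_normalize(mean, std))
```

Alternatively, you could pass the same transform to each underlying dataset before concatenating; wrapping the `ConcatDataset` keeps the transform in a single place.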

Let me know if that works for you!
