How to normalize a ConcatDataset?

Hi all,
I have a dataset where each sample has 7 different channels. Currently I build the datasets for each of my 4 classes separately and then use a ConcatDataset to put them together. I need to perform a z-score normalization on the whole training set, separately for each channel. It looks like I want to use transforms.Normalize to do this, but I’m having trouble figuring out how.

Would the best practice be to subclass the concatenated dataset in some way to add this normalization? Also, is there an efficient way to get the mean/stddev for each channel once I have built the whole training set? Lastly, if I need to build my own Dataset (i.e., just subclass the generic Dataset), is there a way I can draw examples from the three datasets I’ve already made?
Thanks so much for any help!

I’ve created a small code sample here to calculate the mean and std of your dataset on the fly, in case all images do not fit into your memory.
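For readers without access to the linked snippet, here is a minimal sketch of one way to accumulate per-channel statistics batch by batch. It is not the linked code: the function name channel_stats, the random stand-in data, and the assumed batch shape [batch, channels, length] are all illustrative, and it computes the population (biased) std.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_stats(loader):
    """Accumulate per-channel mean and std over batches shaped [B, C, L]."""
    count = 0
    total = None
    total_sq = None
    for batch in loader:
        # TensorDataset batches arrive as a list; take the data tensor
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x = x.double()  # accumulate in float64 to limit rounding error
        if total is None:
            total = torch.zeros(x.shape[1], dtype=torch.double)
            total_sq = torch.zeros(x.shape[1], dtype=torch.double)
        total += x.sum(dim=(0, 2))            # sum over batch and length
        total_sq += (x ** 2).sum(dim=(0, 2))
        count += x.shape[0] * x.shape[2]
    mean = total / count
    # Population std via E[x^2] - E[x]^2, clamped against tiny negatives
    std = (total_sq / count - mean ** 2).clamp_min(0).sqrt()
    return mean.float(), std.float()


# Hypothetical stand-in data: 100 samples, 7 channels, 20 time steps
data = torch.randn(100, 7, 20) * 3 + 1
loader = DataLoader(TensorDataset(data), batch_size=16)
mean, std = channel_stats(loader)  # each has shape [7]
```

Only one batch is held in memory at a time, so this should also work when the full training set does not fit in RAM.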

After you’ve calculated the mean and std, you can create a Dataset and use transforms.Normalize to normalize the images:

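from torch.utils.data import Dataset
from torchvision import transforms

# mean and std are the per-channel statistics computed beforehand
transform = transforms.Normalize(mean=mean, std=std)

class MyDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset(data, transform=transform)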
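For the ConcatDataset case from the original question, a wrapper Dataset is one possible approach. The sketch below is an assumption-laden illustration, not part of the answer above: NormalizedDataset is a hypothetical name, the per-class datasets are random stand-ins that return (sample, label) pairs, and samples are assumed to have shape [7, 20]. Since transforms.Normalize is documented for image tensors, the sketch uses a plain broadcasted subtraction and division instead.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset, TensorDataset


class NormalizedDataset(Dataset):
    """Hypothetical wrapper that z-scores each channel of a base dataset."""

    def __init__(self, base, mean, std):
        self.base = base
        self.mean = mean.view(-1, 1)  # [C, 1] broadcasts over [C, L]
        self.std = std.view(-1, 1)

    def __getitem__(self, index):
        x, y = self.base[index]
        return (x - self.mean) / self.std, y

    def __len__(self):
        return len(self.base)


# Stand-in per-class datasets: 4 classes, 10 samples each, shape [7, 20]
parts = [
    TensorDataset(torch.randn(10, 7, 20) + c, torch.full((10,), c))
    for c in range(4)
]
train = ConcatDataset(parts)

# The stand-in data is small, so the statistics are computed in one go here
all_x = torch.stack([train[i][0] for i in range(len(train))])
mean = all_x.mean(dim=(0, 2))  # shape [7]
std = all_x.std(dim=(0, 2))    # shape [7]

normalized = NormalizedDataset(train, mean, std)
x0, y0 = normalized[0]  # x0 has shape [7, 20]
```

The existing per-class datasets stay untouched; the wrapper only changes what is returned when a sample is drawn.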

Let me know if that works for you!

Thanks so much for the help! One follow-up - I’m not sure how I can modify your code to calculate the mean and stddev separately for each of my channels (i.e. every one of my batches will be shape [batchsz, 7, 20] and I’d like to normalize each of the 7 channels separately). Any thoughts on this? Thanks again.

In the linked code snippet the mean and std are calculated for each channel, such that both estimates will contain 7 values.
If you want to calculate it separately for each channel, you could split the data in each channel and run the code.
Is there a reason you don’t want to calculate the mean and std for every channel in a single run?
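To illustrate the difference between a single scalar and one value per channel, here is a small sketch on a random stand-in batch of shape [batchsz, 7, 20]. The variable names are illustrative; the key assumption is that channels sit in dimension 1, so the reduction runs over dimensions 0 and 2.

```python
import torch

# Stand-in batch: [batchsz, channels, length]
batch = torch.randn(32, 7, 20)

scalar_mean = batch.mean()             # one value for the whole batch
channel_mean = batch.mean(dim=(0, 2))  # 7 values, one per channel
channel_std = batch.std(dim=(0, 2))    # 7 values, one per channel

# Reshape to [1, 7, 1] so the statistics broadcast over batch and length
normalized = (batch - channel_mean[None, :, None]) / channel_std[None, :, None]
```

If the result comes out as a single scalar, the reduction was probably applied over all dimensions instead of only the batch and length dimensions.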

No, sorry, what you said should work then! I thought they were being calculated as single scalar values, not as vectors with a value for each channel - I must have changed something without realizing. Thanks so much for all the help!

No worries! Let me know if you get stuck somewhere or my code doesn’t produce the right results.