For an image enhancement application, I would like to normalize each image of the batch independently before feeding it to the network, and then undo the normalization on the output images using the statistics of their corresponding input image.
I tried to use InstanceNorm2d(), but I'm not sure how to use it in my case.
You could calculate the mean and std for each image and channel, normalize the images, perform your forward pass, and finally undo the normalization on the output:
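Here is a minimal sketch of that approach. It assumes a plain `[batch_size, channels, h, w]` input and uses a single `nn.Conv2d` as a stand-in for the actual enhancement network. The helper names `normalize_per_image` and `denormalize` and the small `eps` added to the std are illustrative choices, not part of any library. `nn.InstanceNorm2d` applies a similar per-image, per-channel normalization, but it does not return the mean and std, so you cannot use it to undo the normalization on the output.

```python
import torch
import torch.nn as nn


def normalize_per_image(x, eps=1e-6):
    # x: [B, C, H, W]; statistics are computed over the spatial dims,
    # giving one mean/std per image and per channel: [B, C, 1, 1]
    mean = x.mean(dim=(2, 3), keepdim=True)
    std = x.std(dim=(2, 3), keepdim=True)
    return (x - mean) / (std + eps), mean, std


def denormalize(y, mean, std, eps=1e-6):
    # Inverse of normalize_per_image, using the stats of the matching input image
    return y * (std + eps) + mean


# Usage (the Conv2d is a hypothetical placeholder for the real network)
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
x = torch.rand(4, 3, 32, 32)
x_norm, mean, std = normalize_per_image(x)
out = denormalize(model(x_norm), mean, std)
```

Because `mean` and `std` keep their singleton spatial dims, they broadcast directly against the network output, as long as the output has the same number of channels as the input.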
Yes, thank you, that's indeed what I was looking for. However, in my use case the input is composed of two images (so the input size is [batch_size, 2*channels, h, w]), and I need to normalize both images jointly to get one mean and one std per channel.
So, starting from your piece of code, how do I compute the statistics not for each channel separately, but for two channels at a time?
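One way to do this, sketched under the assumption that the two images were concatenated along the channel dim (e.g. `torch.cat([img_a, img_b], dim=1)`), is to reshape the input to `[B, 2, C, H, W]`. Channel `k` of both images then sits in the same slice, and reducing over dims `(1, 3, 4)` gives one mean and one std per image pair and per channel. The function names are hypothetical, and the inverse step assumes the network outputs a single `C`-channel image. If your channels are interleaved instead (a0, b0, a1, b1, ...), you would reshape to `[B, C, 2, H, W]` and reduce over `(2, 3, 4)`.

```python
import torch


def joint_normalize(x, num_images=2, eps=1e-6):
    # x: [B, num_images*C, H, W], assumed to be the images concatenated on dim 1
    b, nc, h, w = x.shape
    c = nc // num_images
    g = x.reshape(b, num_images, c, h, w)  # channel k of every image in g[:, :, k]
    # Reduce over the image dim and the spatial dims -> [B, 1, C, 1, 1]
    mean = g.mean(dim=(1, 3, 4), keepdim=True)
    std = g.std(dim=(1, 3, 4), keepdim=True)
    g = (g - mean) / (std + eps)
    # Return the normalized input in its original layout, plus [B, C, 1, 1] stats
    return g.reshape(b, nc, h, w), mean.reshape(b, c, 1, 1), std.reshape(b, c, 1, 1)


def denormalize(y, mean, std, eps=1e-6):
    # y: [B, C, H, W] network output; mean/std: [B, C, 1, 1] from joint_normalize
    return y * (std + eps) + mean
```

The returned `mean` and `std` have shape `[B, C, 1, 1]`, so they broadcast against a `C`-channel output exactly as in the single-image case.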