In an image enhancement application, I would like to normalize each image of the batch independently before feeding it to the network, and then undo the normalization of each output image using the statistics of its corresponding input image.
I tried to use InstanceNorm2d(), but I'm not sure I understand how to use it in my case.
Any ideas?
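For reference, `nn.InstanceNorm2d` with its defaults (`affine=False`, `track_running_stats=False`) does normalize each image and channel with its own statistics. However, its forward pass returns only the normalized tensor, so the mean and std it used are not available afterwards to undo the normalization on the output. The sketch below checks that the layer matches a manual computation using the biased variance and the layer's `eps`. The random batch is only placeholder data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
images = torch.randn(10, 3, 24, 24)  # placeholder batch [B, C, H, W]

# Defaults: affine=False, track_running_stats=False -> per-image, per-channel stats
inorm = nn.InstanceNorm2d(3)
out_layer = inorm(images)

# Manual equivalent: biased variance, eps added inside the square root
mean = images.mean(dim=(2, 3), keepdim=True)
var = images.var(dim=(2, 3), unbiased=False, keepdim=True)
out_manual = (images - mean) / torch.sqrt(var + inorm.eps)

print(torch.allclose(out_layer, out_manual, atol=1e-5))
```

Since the statistics have to be kept for the inverse step anyway, computing them manually is the simpler route here.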
You could calculate the mean and std for each image and channel, normalize the images, perform your forward pass, and finally undo the normalization on the output:
```python
import torch
import torch.nn as nn

batch_size = 10
channels = 3
h, w = 24, 24
images = torch.randn(batch_size, channels, h, w)

# Per-image, per-channel statistics with shape [batch_size, channels, 1, 1]
im_mean = images.view(batch_size, channels, -1).mean(2).view(batch_size, channels, 1, 1)
im_std = images.view(batch_size, channels, -1).std(2).view(batch_size, channels, 1, 1)
images_norm = (images - im_mean) / im_std

model = nn.Conv2d(3, 3, 3, 1, 1)
output = model(images_norm)
# Undo the normalization with the statistics of the corresponding input image
output = (output * im_std) + im_mean
```
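The same statistics can be computed a bit more compactly with `keepdim=True`, which avoids the reshaping. The sketch below wraps the two steps in helper functions; `normalize`, `denormalize`, and the small `eps` (which guards against division by zero for a constant channel) are assumptions for illustration and not part of the code above.

```python
import torch
import torch.nn as nn

def normalize(x, eps=1e-5):
    """Hypothetical helper: normalize each image and channel, return the stats too."""
    mean = x.mean(dim=(2, 3), keepdim=True)      # [B, C, 1, 1]
    std = x.std(dim=(2, 3), keepdim=True) + eps  # eps guards against flat channels
    return (x - mean) / std, mean, std

def denormalize(x, mean, std):
    """Hypothetical helper: map a normalized tensor back with the stored stats."""
    return x * std + mean

torch.manual_seed(0)
images = torch.randn(10, 3, 24, 24)
images_norm, im_mean, im_std = normalize(images)

model = nn.Conv2d(3, 3, 3, 1, 1)
output = denormalize(model(images_norm), im_mean, im_std)

# Without the model in between, the round trip recovers the input
restored = denormalize(images_norm, im_mean, im_std)
```

Because `mean` and `std` keep the shape `[B, C, 1, 1]`, they broadcast against any output with the same number of channels.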
Let me know if this works for you or if I've misunderstood your use case.
Yes, thank you, that's indeed what I was looking for. However, in my use case the input consists of two images (so the input shape is [batch_size, 2*channels, h, w]), and I need to normalize both images jointly to get one mean and one std per channel.
So, starting from your code, how do I compute the statistics not for each channel separately, but for two channels at a time?
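One possible approach, sketched under an assumption: the input was built with `torch.cat([img_a, img_b], dim=1)`, so channel `c` of the first image sits at index `c` and channel `c` of the second at index `c + channels`. Splitting the channel dimension into `(image, channel)` then lets you reduce over the image axis and both spatial axes at once. `joint_stats` is a hypothetical helper name, and the `Conv2d` with `2*channels` inputs and `channels` outputs is only a placeholder for the enhancement network.

```python
import torch
import torch.nn as nn

def joint_stats(x, channels):
    """Hypothetical helper: one mean/std per sample and channel, pooled over both images.

    Assumes x has shape [B, 2*channels, H, W] and was built with
    torch.cat([img_a, img_b], dim=1).
    """
    b, _, h, w = x.shape
    # Split the channel dim into (image, channel): [B, 2, C, H, W]
    pairs = x.reshape(b, 2, channels, h, w)
    # Reduce over the image axis and both spatial axes -> [B, C]
    mean = pairs.mean(dim=(1, 3, 4)).view(b, channels, 1, 1)
    std = pairs.std(dim=(1, 3, 4)).view(b, channels, 1, 1)
    return mean, std

torch.manual_seed(0)
batch_size, channels, h, w = 10, 3, 24, 24
img_a = torch.randn(batch_size, channels, h, w)
img_b = 2 * torch.randn(batch_size, channels, h, w) + 1
x = torch.cat([img_a, img_b], dim=1)  # [B, 2*C, H, W]

mean, std = joint_stats(x, channels)
# Tile the stats to [B, 2*C, 1, 1] so channels c and c + C share the same values
x_norm = (x - mean.repeat(1, 2, 1, 1)) / std.repeat(1, 2, 1, 1)

# Placeholder network: 2*C input channels, C output channels (one enhanced image)
model = nn.Conv2d(2 * channels, channels, 3, 1, 1)
output = model(x_norm)
output = output * std + mean  # undo with the joint per-channel statistics
```

If the network instead returns both enhanced images (`[B, 2*C, H, W]`), the inverse step would use the tiled `mean.repeat(1, 2, 1, 1)` and `std.repeat(1, 2, 1, 1)` as well.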