How does torchvision.transforms.Normalize work?

Hi.
Whenever I normalize my input images with their mean and std and then plot the images to visualize them, I get this message along with plots of images that look distorted:
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
If anyone could help me understand why this is happening, that would be great.

[attached image: plots of the distorted images]


torchvision.transforms.Normalize uses the mean and std to standardize the inputs, so that they have zero mean and unit variance, i.e. it computes output = (input - mean) / std for each channel.

The library you use to show these images (probably matplotlib) clips float image arrays to [0, 1], which distorts the normalized values and produces the warning you are seeing.
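
A minimal sketch (assuming the commonly used mean and std of 0.5 per channel and a random fake image) that shows why the warning appears: after Normalize, the values fall outside the [0, 1] float range matplotlib expects, so imshow clips them:

import torch
from torchvision import transforms
import matplotlib.pyplot as plt

# fake RGB image with values in [0, 1]
img = torch.rand(3, 224, 224)

norm = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
img_norm = norm(img)

# roughly -1 and 1, i.e. outside matplotlib's expected [0, 1] range for floats
print(img_norm.min().item(), img_norm.max().item())

# imshow expects an HxWxC array and clips float values to [0, 1], printing the warning
plt.imshow(img_norm.permute(1, 2, 0).numpy())
plt.show()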

Thank you very much @ptrblck. Yes, I am using matplotlib to show the images. Does getting distorted images mean that I am doing something wrong? If not, is there a way I could plot the images without them being distorted?

The easiest way would be to plot them before normalizing.
However, if that’s not possible, you could also undo the normalization:

import torch
from torchvision import transforms

# fake image with values in [0, 1]
x = torch.empty(3, 224, 224).uniform_(0, 1)

mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
norm = transforms.Normalize(mean, std)
x_norm = norm(x)

# undo the normalization per channel: x = x_norm * std + mean
x_restore = x_norm * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)

print((x_restore - x).abs().max())
> tensor(0.)
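
If you then want to display the restored tensor with matplotlib, note that imshow expects a channels-last array; a small sketch continuing from the snippet above:

import matplotlib.pyplot as plt

# move the channel dimension last, since imshow expects an HxWxC array
plt.imshow(x_restore.permute(1, 2, 0).numpy())
plt.show()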

Thank you very much @ptrblck. I will give this a try 🙂