I was messing around a bit with `torchvision.transforms.Normalize` and successfully identified a bug in my code where I had applied `Normalize` to a batch of images, i.e. a 4D tensor of shape (batch_size x channels x width x height), instead of a single image, i.e. a 3D tensor of shape (channels x width x height). When I then looked at the source code, I was surprised that no exception had been thrown, because I found a check for whether the input is 3D. Specifically, when I use the imported `torchvision.transforms.Normalize`, no exception is thrown. However, when I copy the code from the source and define all the components myself, it throws the exception, as I suspected it should. What am I missing?

Source: https://github.com/pytorch/vision/blob/master/torchvision/transforms.py
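
A quick way to check which code the import actually resolves to (a small diagnostic sketch; `inspect` is from the standard library, and the point is only to compare the installed package against master):

```
import inspect
import torchvision
from torchvision.transforms import Normalize

# Print the installed version and the source of the class that actually
# gets imported, to compare against the GitHub file linked above.
print(torchvision.__version__)
print(inspect.getsource(Normalize))
```

The full example reproducing both behaviours: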

```
import torch
from torchvision.transforms import Normalize
# Define a random batch of images
batch_size, channels, width, height = 100, 3, 32, 32
batch_of_images = torch.randn(batch_size, channels, width, height)
# Use the imported normalisation
means = list(batch_of_images.transpose(0, 1).contiguous().view(channels, -1).mean(dim=1))
stds = list(batch_of_images.transpose(0, 1).contiguous().view(channels, -1).std(dim=1))
imported_normalise = Normalize(means, stds)
wrong_normalised_batch = imported_normalise(batch_of_images)  # no exception is raised
# Define everything myself according to the code in the source
def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3

def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.
    See ``Normalize`` for more details.
    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor
class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor``, i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized Tensor image.
        """
        return normalize(tensor, self.mean, self.std)

# Now try it with the self-defined class
selfdefined_normalise = Normalize(means, stds)
selfdefined_normalise(batch_of_images)  # throws the TypeError, as it should
```
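
For completeness, what I had meant to do, applying the 3D-only transform to each image in the batch separately, would look something like this (a sketch of the intended usage, reusing `imported_normalise` from above; not part of the torchvision source):

```
# Normalize each (C, H, W) image individually and stack the results
# back into a (N, C, H, W) batch.
properly_normalised_batch = torch.stack(
    [imported_normalise(img) for img in batch_of_images])
```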