Understanding transforms.Normalize()

The messy output is expected: matplotlib either clips the input or tries to rescale it, which creates these kinds of artifacts (especially since you are normalizing channel-wise with different values).

If you would like to visualize the images, you should use the raw images (in [0, 255]) or the normalized ones (in [0, 1]).
Alternatively, you could also unnormalize them, but I think the first approach would be simpler.
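If you do want to unnormalize, it is just the inverse of what Normalize applies. A minimal sketch, assuming the commonly used ImageNet mean/std values (substitute your own stats):

```python
import torch

# Hypothetical per-channel stats; use whatever you passed to transforms.Normalize
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def unnormalize(img):
    # Inverts Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean
    return img * std + mean

# Round-trip check on a random image tensor in [0, 1]
x = torch.rand(3, 8, 8)
x_norm = (x - mean) / std
x_restored = unnormalize(x_norm)
print(torch.allclose(x, x_restored, atol=1e-6))  # True
```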

If you are using a custom Dataset, just add a load_image method and use it for visualization:

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class MyDataset(Dataset):
    def __init__(self, image_paths, targets, transform=None):
        self.image_paths = image_paths
        self.targets = targets
        self.transform = transform

    def load_image(self, index):
        # Returns the raw PIL image without applying any transform
        image_path = self.image_paths[index]
        img = Image.open(image_path)
        return img

    def __getitem__(self, index):
        x = self.load_image(index)
        y = self.targets[index]
        
        if self.transform:
            x = self.transform(x)
            
        return x, y
    
    def __len__(self):
        return len(self.image_paths)

image_paths = [...]
targets = ...
dataset = MyDataset(image_paths, targets, transform=transforms.ToTensor())
img_to_vis = dataset.load_image(index=0)
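When plotting a transformed sample instead of the raw image, note that ToTensor() returns a tensor in channels-first layout, while matplotlib expects channels-last. A small sketch with a dummy tensor standing in for a dataset sample:

```python
import torch

# Simulate a ToTensor() output: shape (C, H, W), values in [0, 1]
img_tensor = torch.rand(3, 32, 32)

# matplotlib's imshow expects (H, W, C), so move the channel axis last
img_np = img_tensor.permute(1, 2, 0).numpy()
print(img_np.shape)  # (32, 32, 3)

# plt.imshow(img_np)      # displays correctly since values are in [0, 1]
# plt.imshow(img_to_vis)  # or pass the raw PIL image directly
```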

PS: Unrelated to your question, but your PyTorch version is quite old. I would recommend updating it to the latest stable version. You’ll find the install instructions here.
