The messy output is quite normal, as matplotlib either clips the input or tries to scale it, which creates this kind of artifact (also because you are normalizing channel-wise with different values).
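To see why, here is a small sketch of what channel-wise normalization does to values that started in [0, 1]. The mean/std below are an assumption (the ImageNet statistics often used with torchvision models), not necessarily the ones from your code. For float RGB data, `imshow` expects [0, 1] and clips anything outside that range, so most of the normalized pixels end up saturated:

```python
import torch

# Assumed per-channel statistics (ImageNet values often used with torchvision);
# substitute the ones from your own Normalize call.
mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]

# A deterministic CHW "image" with values in [0, 1], like ToTensor output
x = torch.linspace(0, 1, 3 * 4 * 4).reshape(3, 4, 4)

# Channel-wise normalization, as transforms.Normalize computes it
x_norm = (x - mean) / std

# imshow clips float RGB data to [0, 1]; this clamp mimics that behavior
x_clipped = x_norm.clamp(0, 1)
out_of_range = ((x_norm < 0) | (x_norm > 1)).float().mean().item()
print(f"min={x_norm.min().item():.2f}, max={x_norm.max().item():.2f}, "
      f"clipped fraction={out_of_range:.0%}")
```

Because each channel is shifted and scaled by different amounts, the clipping affects the channels unevenly, which is what produces the odd colors.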
If you would like to visualize the images, you should use the raw images (in [0, 255]) or the ones scaled to [0, 1] (e.g. the output of `transforms.ToTensor()` before `Normalize` is applied).
Alternatively, you could also unnormalize them, but I think the first approach would be simpler.
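If you prefer unnormalizing, a minimal sketch could look like this. It assumes CHW tensors and uses the ImageNet mean/std as placeholder values; `unnormalize` and `to_imshow` are hypothetical helper names, not torchvision functions. It simply inverts `(x - mean) / std` per channel and moves the channel dimension last, as `imshow` expects HWC:

```python
import torch

# Assumed statistics: replace them with the mean/std you passed to Normalize
MEAN = torch.tensor([0.485, 0.456, 0.406])
STD = torch.tensor([0.229, 0.224, 0.225])


def unnormalize(x: torch.Tensor, mean: torch.Tensor = MEAN,
                std: torch.Tensor = STD) -> torch.Tensor:
    """Invert a channel-wise Normalize on a CHW tensor: x * std + mean."""
    return x * std[:, None, None] + mean[:, None, None]


def to_imshow(x: torch.Tensor):
    """Turn a normalized CHW tensor into an HWC NumPy array for plt.imshow."""
    # clamp guards against small overshoots, e.g. from floating-point error
    return unnormalize(x).clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()


# Usage: round-trip a random "image" through the normalization and back
img = torch.rand(3, 8, 8)
normalized = (img - MEAN[:, None, None]) / STD[:, None, None]
restored = unnormalize(normalized)
print(torch.allclose(restored, img, atol=1e-6))
```

You would then call `plt.imshow(to_imshow(x))` on a normalized sample.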
If you are using a custom `Dataset`, just add another `load_image` function and use it for visualization:
```python
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self, image_paths, targets, transform=None):
        self.image_paths = image_paths
        self.targets = targets
        self.transform = transform

    def load_image(self, index):
        # Load the raw PIL image without applying any transformation
        image_path = self.image_paths[index]
        img = Image.open(image_path)
        return img

    def __getitem__(self, index):
        x = self.load_image(index)
        y = self.targets[index]
        # Apply the transformation (e.g. ToTensor + Normalize) to the sample
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.image_paths)


image_paths = [...]
targets = ...
dataset = MyDataset(image_paths, targets, transform=transforms.ToTensor())
# Use the untransformed image for visualization
img_to_vis = dataset.load_image(index=0)
```
PS: Unrelated to your question, but your PyTorch version is quite old. I would recommend updating it to the latest stable version. You'll find the install instructions on the PyTorch website.