@FilipAndersson245 and I found that the correct way to unnormalize is `x * std + mean`. We also had to clamp a few values that fell outside [0, 1].
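Why this works: the forward normalization (as in torchvision's `Normalize`) computes, per channel $c$, $(x_c - \mu_c)/\sigma_c$, so solving for $x_c$ gives the inverse directly. The clamp is applied afterwards and is not part of the algebra.

```latex
y_c = \frac{x_c - \mu_c}{\sigma_c}
\quad\Longrightarrow\quad
x_c = y_c\,\sigma_c + \mu_c
```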
For a single image, the code looks something like this:
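```python
import torch

def inv_normalize(img):
    # img: a single normalized image tensor of shape (3, H, W)
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).unsqueeze(-1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).unsqueeze(-1)
    # Per-channel x * std + mean on the flattened (3, H*W) tensor, then restore the shape
    img = (img.reshape(3, -1) * std + mean).view(img.shape)
    img = img.clamp(0, 1)  # clamp the few values that fall outside [0, 1]
    return img
```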
Feel free to chime in if the code can be written more simply!
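One possibly simpler variant, sketched under the assumption that the input is a float tensor of shape (3, H, W) or (N, 3, H, W): shaping mean and std as (3, 1, 1) lets PyTorch broadcasting apply them per channel, so the image never has to be flattened and reshaped. `MEAN`, `STD` and `inv_normalize_broadcast` are hypothetical names chosen for this sketch.

```python
import torch

# ImageNet channel statistics, same values as the snippet above (hypothetical names)
MEAN = torch.tensor([0.485, 0.456, 0.406])
STD = torch.tensor([0.229, 0.224, 0.225])


def inv_normalize_broadcast(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: undo (x - mean) / std for (3, H, W) or (N, 3, H, W)."""
    # Shape (3, 1, 1) broadcasts over the trailing H and W dimensions,
    # and over a leading batch dimension if one is present.
    mean = MEAN.to(img.device, img.dtype).view(3, 1, 1)
    std = STD.to(img.device, img.dtype).view(3, 1, 1)
    # Same x * std + mean as above, followed by the clamp to [0, 1]
    return (img * std + mean).clamp(0, 1)


if __name__ == "__main__":
    # Round trip: normalize a random image, then invert it
    x = torch.rand(3, 4, 5)
    y = (x - MEAN.view(3, 1, 1)) / STD.view(3, 1, 1)
    print(torch.allclose(inv_normalize_broadcast(y), x, atol=1e-6))
```

Because broadcasting aligns trailing dimensions, the same function also covers a whole batch without a loop.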