Does this code calculate the correct mean and std to use for torchvision.transforms.Normalize()
?
import torch
import torchvision
import torchvision.transforms as transforms
# Calculate the mean and std of a dataset
def calc_data_mean(data_path):
traindata = torchvision.datasets.ImageFolder(
root=data_path,
transform=transforms.Compose([transforms.ToTensor()])
)
image_means = torch.stack([torch.mean(t, dim=(1, 2)) for t, c in traindata])
image_means = image_means.mean(0)
print('RGB mean:', image_means)
image_stds = torch.stack([torch.std(t, dim=(1, 2)) for t, c in traindata])
image_stds = image_stds.std(0)
print('RGB std:', image_stds)
quit()