Is this correct? I am not sure what 0, 2, 3 in the code represent (they are the values passed as `dim` to `torch.mean`). It is from a YouTube video. I see others commenting about their confusion on 0, 2, 3 too.
def get_mean_std(loader):
    """Return the per-channel mean and standard deviation of a dataset.

    Each batch is reduced over dims 0, 2 and 3, so only dim 1 survives in
    the result. For image batches laid out as (batch, channel, height,
    width) -- the layout this code appears to assume; confirm against your
    dataset -- that means: average over every image (0) and every pixel
    row/column (2, 3), keeping one value per channel (1).

    Args:
        loader: Iterable yielding ``(data, target)`` pairs, where ``data``
            is a 4-D floating-point tensor. ``target`` is ignored.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per index of
        dim 1. ``std`` is the population standard deviation.

    Raises:
        ValueError: If the loader yields no elements at all.
    """
    # VAR[X] = E[X**2] - E[X]**2
    channels_sum, channels_squared_sum, num_elements = 0, 0, 0
    for data, _ in loader:
        # Accumulate raw sums and element counts rather than per-batch means:
        # averaging batch means would over-weight a smaller final batch.
        channels_sum += torch.sum(data, dim=[0, 2, 3])
        channels_squared_sum += torch.sum(data**2, dim=[0, 2, 3])
        # Number of values reduced into each dim-1 entry for this batch.
        num_elements += data.size(0) * data.size(2) * data.size(3)
    if num_elements == 0:
        raise ValueError("loader yielded no data; cannot compute mean and std")
    mean = channels_sum / num_elements
    variance = channels_squared_sum / num_elements - mean**2
    # Rounding can push the variance of a (near-)constant channel slightly
    # below zero; clamp so the square root yields 0 instead of NaN.
    std = torch.clamp(variance, min=0) ** 0.5
    return mean, std
# Compute the statistics of the 'train' split with get_mean_std and print
# them; the line that follows the print is the pasted console output.
train_mean, train_std = get_mean_std(dataloaders_dict['train'])
print(train_mean, train_std)
tensor([0.7031, 0.5487, 0.6750]) tensor([0.2115, 0.2581, 0.1952])
# Compute the statistics of the 'test' split with get_mean_std and print
# them; the line that follows the print is the pasted console output.
test_mean, test_std = get_mean_std(dataloaders_dict['test'])
print(test_mean, test_std)
tensor([0.7048, 0.5509, 0.6763]) tensor([0.2111, 0.2576, 0.1979])
# Compute the statistics of the 'val' split with get_mean_std and print
# them; the line that follows the print is the pasted console output.
val_mean, val_std = get_mean_std(dataloaders_dict['val'])
print(val_mean, val_std)
tensor([0.7016, 0.5549, 0.6784]) tensor([0.2099, 0.2583, 0.1998])