This is my solution:
# Compute one scalar mean and standard deviation over every element yielded by
# ``train_loader``, pooled across all batches. ``sum()`` is taken with no ``dim``,
# so channels are NOT separated.
#
# Uses Var[X] = E[X^2] - E[X]^2 with running sums, so the whole dataset never
# has to be held in memory at once.
import torch

# Running accumulators. NOTE: despite its name, ``mean`` holds the running
# *sum* of all elements until it is divided by ``count`` below.
mean = 0.0
meansq = 0.0
count = 0
# NOTE(review): assumes each item yielded by train_loader is a single tensor;
# if it yields (inputs, targets) pairs, unpack them before summing -- TODO confirm.
for data in train_loader:
    # Work in float64: integer inputs (e.g. uint8 pixels) would overflow in
    # ``data**2``, and float32 sums lose precision over many batches.
    data = data.to(torch.float64)
    mean = mean + data.sum()  # accumulate across batches, do not overwrite
    meansq = meansq + (data**2).sum()
    count += data.numel()

if count == 0:
    raise ValueError("train_loader yielded no elements; cannot compute mean/std")

total_mean = mean / count
# Floating-point rounding can push E[X^2] - E[X]^2 slightly below zero when the
# true variance is ~0; clamp so sqrt does not return NaN.
total_var = (meansq / count) - (total_mean**2)
total_std = torch.sqrt(torch.clamp(total_var, min=0.0))
print("mean: " + str(total_mean))
print("std: " + str(total_std))