I would post my code first and I would explain it later:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.transforms import Normalize
MEAN = torch.as_tensor([0.485, 0.456, 0.406])
STD = torch.as_tensor([0.229, 0.224, 0.225])
SCALE = 1 / STD
OFFSET = - SCALE * MEAN
R = torch.arange(196, dtype=float) / 255
G = torch.arange(196, dtype=float) / 255
B = torch.arange(196, dtype=float) / 255
IMG = torch.stack([R, G, B]).reshape(3, 14, 14)
def method_0():
# transform
transform = Normalize(MEAN, STD)
return transform(IMG)
def method_1():
# per channel norm (mean & std) -> reshape & stack
r_norm = ((R - MEAN[0]) / STD[0]).reshape(14, 14)
g_norm = ((G - MEAN[1]) / STD[1]).reshape(14, 14)
b_norm = ((B - MEAN[2]) / STD[2]).reshape(14, 14)
return torch.stack([r_norm, g_norm, b_norm])
def method_2():
# per channel norm (offset & scale) -> reshape & stack
r_norm = (R * SCALE[0] + OFFSET[0]).reshape(14, 14)
g_norm = (G * SCALE[1] + OFFSET[1]).reshape(14, 14)
b_norm = (B * SCALE[2] + OFFSET[2]).reshape(14, 14)
return torch.stack([r_norm, g_norm, b_norm])
def method_3():
# flatten channels, offset & scale
img_norm = (IMG.reshape(-1, 3) * SCALE + OFFSET).reshape(IMG.shape)
return img_norm
def method_4():
# flatten channels, mean & std
img_norm = ((IMG.reshape(-1, 3) - MEAN) / STD).reshape(IMG.shape)
return img_norm
def method_5():
# reshape scale and offset such that they can be broadcasted
img_norm = IMG * SCALE.reshape(3, 1, 1) + OFFSET.reshape(3, 1, 1)
return img_norm
def method_6():
# reshape mean and std such that they can be broadcasted
img_norm = (IMG - MEAN.reshape(3, 1, 1)) / STD.reshape(3, 1, 1)
return img_norm
def calculate_mse(arr):
re = []
for i in arr:
temp = []
for j in arr:
temp.append(
torch.log(torch.mean((i - j)**2)+1e-8).item()
)
re.append(temp)
return re
result = [
method_0(), method_1(), method_2(), method_3(), method_4(), method_5(), method_6()
]
mse_values = calculate_mse(result)
# Plot heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(
mse_values, annot=True, cmap="YlGnBu", fmt=".4f",
xticklabels=[f"method_{i}" for i in range(7)],
yticklabels=[f"method_{i}" for i in range(7)]
)
plt.title("Heatmap of MSE Between Methods")
plt.xlabel("Methods")
plt.ylabel("Methods")
plt.show()
According to my intuition, these 7 methods should do exactly the same thing — normalize the input image according to mean & std calculated on ImageNet dataset. However, for some unknown reasons, method 3 and method 4 yield different results from others.
The heat map generated by this code should be something like:
Could anyone help me out? Or did I misunderstand broadcasting/reshape mechanism? Thanks in advance!
FYI, I’m running these on Ubuntu 22.04 LTS. Python version is 3.8, and PyTorch version is 2.4.