Would broadcast or reshape affect the result?

I would post my code first and I would explain it later:

import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.transforms import Normalize

MEAN = torch.as_tensor([0.485, 0.456, 0.406])
STD = torch.as_tensor([0.229, 0.224, 0.225])
SCALE = 1 / STD
OFFSET = - SCALE * MEAN

R = torch.arange(196, dtype=float) / 255
G = torch.arange(196, dtype=float) / 255
B = torch.arange(196, dtype=float) / 255
IMG = torch.stack([R, G, B]).reshape(3, 14, 14)

def method_0():
    # transform
    transform = Normalize(MEAN, STD)
    return transform(IMG)

def method_1():
    # per channel norm (mean & std) -> reshape & stack
    r_norm = ((R - MEAN[0]) / STD[0]).reshape(14, 14)
    g_norm = ((G - MEAN[1]) / STD[1]).reshape(14, 14)
    b_norm = ((B - MEAN[2]) / STD[2]).reshape(14, 14)
    return torch.stack([r_norm, g_norm, b_norm])

def method_2():
    # per channel norm (offset & scale) -> reshape & stack
    r_norm = (R * SCALE[0] + OFFSET[0]).reshape(14, 14)
    g_norm = (G * SCALE[1] + OFFSET[1]).reshape(14, 14)
    b_norm = (B * SCALE[2] + OFFSET[2]).reshape(14, 14)
    return torch.stack([r_norm, g_norm, b_norm])

def method_3():
    # flatten channels, offset & scale
    img_norm = (IMG.reshape(-1, 3) * SCALE + OFFSET).reshape(IMG.shape)
    return img_norm

def method_4():
    # flatten channels, mean & std
    img_norm = ((IMG.reshape(-1, 3) - MEAN) / STD).reshape(IMG.shape)
    return img_norm

def method_5():
    # reshape scale and offset such that they can be broadcasted
    img_norm = IMG * SCALE.reshape(3, 1, 1) + OFFSET.reshape(3, 1, 1)
    return img_norm

def method_6():
    # reshape mean and std such that they can be broadcasted
    img_norm = (IMG - MEAN.reshape(3, 1, 1)) / STD.reshape(3, 1, 1)
    return img_norm

def calculate_mse(arr):
    re = []
    for i in arr:
        temp = []
        for j in arr:
            temp.append(
                torch.log(torch.mean((i - j)**2)+1e-8).item()
            )
        re.append(temp)
        
    return re

result = [
    method_0(), method_1(), method_2(), method_3(), method_4(), method_5(), method_6()
]

mse_values = calculate_mse(result)

# Plot heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(
    mse_values, annot=True, cmap="YlGnBu", fmt=".4f", 
    xticklabels=[f"method_{i}" for i in range(7)], 
    yticklabels=[f"method_{i}" for i in range(7)]
)
plt.title("Heatmap of MSE Between Methods")
plt.xlabel("Methods")
plt.ylabel("Methods")
plt.show()

According to my intuition, these 7 methods should do exactly the same thing — normalize the input image according to mean & std calculated on ImageNet dataset. However, for some unknown reasons, method 3 and method 4 yield different results from others.

The heat map generated by this code should be something like:

Could anyone help me out? Or did I misunderstand broadcasting/reshape mechanism? Thanks in advance!

FYI, I’m running these on Ubuntu 22.04 LTS. Python version is 3.8, and PyTorch version is 2.4.