[JIT] Compilation-induced discrepancy in F.instance_norm when passing input as running stats

Bug Description

When a model calls F.instance_norm on a broadcast product and passes the input tensor x as both running_mean and running_var, the output of the torch.jit.script-compiled model differs from the eager-mode output.
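For reference, here is a minimal pure-Python sketch of what instance normalization computes for one sample. It assumes the default use_input_stats=True and no affine weight or bias. Each channel is normalized with its own mean and biased variance, so the values passed as running_mean/running_var (here x) should not enter the returned tensor. The helper name instance_norm_1sample is hypothetical and is not part of PyTorch.

```python
import math
from typing import List


def instance_norm_1sample(channels: List[List[float]], eps: float = 1e-5) -> List[List[float]]:
    """Hypothetical helper: normalize each channel of ONE sample by its own stats.

    `channels` holds C channels, each flattened to H*W values.
    Uses the biased variance (divide by N), as instance/batch norm do for the
    output. Running stats are deliberately absent: with use_input_stats=True
    they are not used to compute the output.
    """
    out = []
    for values in channels:
        n = len(values)
        mean = sum(values) / n
        var = sum((v - mean) ** 2 for v in values) / n  # biased variance
        denom = math.sqrt(var + eps)
        out.append([(v - mean) / denom for v in values])
    return out


if __name__ == "__main__":
    # One sample with two channels of a 2x2 spatial map, flattened.
    x = [[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 12.0, 12.0]]
    for ch in instance_norm_1sample(x):
        print([round(v, 4) for v in ch])
```

Given identical inputs, eager and scripted runs should therefore agree up to floating-point noise.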

import torch
import torch.nn as nn
import torch.nn.functional as F


class NeuralModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlist = nn.ParameterList([
            nn.Parameter(torch.tensor([[
                [[5.0269, 3.4145, 3.8807],
                 [6.3197, 6.5815, 5.3826],
                 [3.4846, 4.3035, 6.9038]],

                [[3.7721, 3.0422, 4.3315],
                 [3.2518, 3.5617, 6.4604],
                 [3.5747, 5.1481, 6.6348]],

                [[3.0173, 3.5291, 6.8552],
                 [5.6125, 6.2321, 6.3142],
                 [4.7299, 4.2638, 5.2731]]
            ]], dtype=torch.float64))
        ])

    def forward(self, x):
        expanded = x.unsqueeze(-1).unsqueeze(-1)  # (3,) -> (3,1,1)
        multiplied = x * self.mlist[0]  # Broadcast multiply: (3,) * (1,3,3,3) -> (1,3,3,3)
        inst_norm = F.instance_norm(multiplied, x, x)  # x passed as running_mean and running_var
        log_softmax = F.log_softmax(multiplied, dim=-1)
        bilinear = F.interpolate(log_softmax, scale_factor=1.0, mode='bilinear')  # fix: 1.0 instead of 1

        return {
            'v0_0': expanded,
            'v6_0': inst_norm,
            'v2_0': bilinear
        }

input_data = torch.tensor([494.91649119, 528.01665228, 492.01463052], dtype=torch.float64)
model = NeuralModel()
with torch.no_grad():
    output_eager = model(input_data)

model_scripted = torch.jit.script(model)
with torch.no_grad():
    output_scripted = model_scripted(input_data)

print("NONJIT (v6_0):", output_eager['v6_0'])
print("JIT  (v6_0):", output_scripted['v6_0'])
print("consistency:", torch.allclose(output_eager['v6_0'], output_scripted['v6_0']))

Output:

NONJIT (v6_0): tensor([[[[-0.0792, -1.1555, -0.9882],
          [ 0.9260,  1.4719,  0.1728],
          [-1.2785, -0.4180,  1.3487]],

         [[-0.5777, -0.9980, -0.1514],
          [-0.9931, -0.5555,  1.5384],
          [-0.7353,  0.7958,  1.6768]],

         [[-1.7688, -1.1584,  1.3313],
          [ 0.3497,  1.1958,  0.8923],
          [-0.3708, -0.5185,  0.0474]]]], dtype=torch.float64)
JIT  (v6_0): tensor([[[[ 0.1992, -1.2898, -1.0547],
          [ 1.2781,  1.2202,  0.0836],
          [-1.0879, -0.5852,  1.2365]],

         [[-0.3739, -1.1909, -0.2243],
          [-0.8553, -0.7344,  1.5646],
          [-0.5565,  0.6596,  1.7112]],

         [[-1.6616, -1.3545,  1.2870],
          [ 0.7235,  1.0046,  0.8354],
          [-0.0877, -0.7133, -0.0335]]]], dtype=torch.float64)
consistency: False
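One possibility worth ruling out (an assumption, not verified here): with use_input_stats=True, F.instance_norm also updates running_mean and running_var in place, using momentum (default 0.1) and the unbiased variance. Because x is passed as both, the eager call may overwrite input_data before the scripted call reuses it, in which case the two runs would not see the same input. The pure-Python sketch below models that update on plain lists. The helper update_running_stats is hypothetical, and the write-back order (mean first, variance second) is an assumption about the ATen implementation.

```python
from typing import List


def update_running_stats(
    running_mean: List[float],
    running_var: List[float],
    channels: List[List[float]],
    momentum: float = 0.1,
) -> None:
    """Hypothetical sketch of the in-place running-stat update (one sample, C channels).

    New values are computed from snapshots of both buffers first, then written
    back mean-first, variance-second (assumed order). If both arguments are the
    SAME list, the variance update is what remains in it.
    """
    old_mean = list(running_mean)
    old_var = list(running_var)
    new_mean, new_var = [], []
    for c, values in enumerate(channels):
        n = len(values)
        mean = sum(values) / n
        unbiased_var = sum((v - mean) ** 2 for v in values) / (n - 1)
        new_mean.append((1 - momentum) * old_mean[c] + momentum * mean)
        new_var.append((1 - momentum) * old_var[c] + momentum * unbiased_var)
    running_mean[:] = new_mean  # first write-back
    running_var[:] = new_var    # second write-back; overwrites the first if aliased


if __name__ == "__main__":
    x = [1.0, 2.0]  # stands in for the tensor passed as both running stats
    data = [[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 12.0, 12.0]]
    before = list(x)
    update_running_stats(x, x, data)  # same object for mean and var, as in the repro
    print("before:", before, "after:", x)
```

In the repro itself, printing input_data before and after the eager call, or passing input_data.clone() to each model, would show whether this accounts for the mismatch.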

Versions

Collecting environment information…
PyTorch version: 2.0.1+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Could not collect
GCC version: Could not collect
Clang version: 20.1.2
CMake version: version 4.0.0
Libc version: N/A

Python version: 3.9.7 (tags/v3.9.7:1016ef3, Aug 30 2021, 20:19:38) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 560.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] torch==2.0.1
[conda] Could not collect