Problem converting model to ONNX

I am working on the code below to convert my UNet++ model to ONNX, and the inference results from PyTorch and ONNX Runtime show huge differences (max abs diff ≈ 2.78; see the terminal output at the end).

I don’t actually need any dynamic axes; they are only in there because GPT suggested them.

Environment: Windows, torch 2.7.1+cu128, onnxruntime-gpu 1.22.0, onnx 1.19.1.

Model Implementation

import torch
from torch import nn

__all__ = ["UNet", "NestedUNet"]


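# Basic building block: two (Conv3x3 -> BatchNorm -> ReLU) stages.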
class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out


class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

        output = self.final(x0_4)
        return output


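# UNet++ (Nested U-Net): a UNet with dense, nested skip pathways between encoder and decoder.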
class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(
            nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_2 = VGGBlock(
            nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv2_2 = VGGBlock(
            nb_filter[2] * 2 + nb_filter[3], nb_filter[2], nb_filter[2]
        )

        self.conv0_3 = VGGBlock(
            nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_3 = VGGBlock(
            nb_filter[1] * 3 + nb_filter[2], nb_filter[1], nb_filter[1]
        )

        self.conv0_4 = VGGBlock(
            nb_filter[0] * 4 + nb_filter[1], nb_filter[0], nb_filter[0]
        )

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
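        # x{i}_{j}: feature map at resolution level i, dense-skip stage j.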
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]

        else:
            output = self.final(x0_4)
            return output

Export and comparison script

import torch
import numpy as np
from archs import NestedUNet, UNet
from pathlib import Path
import os
import onnxruntime as ort

experiment_dir = Path("models/wafer_defect_dataset_0918_NestedUNet_woDS_0")
load_model_path = str(experiment_dir / "model.pth")
export_model_path = str(experiment_dir / "model.onnx")

if Path(export_model_path).exists():
    print(f"Model already exists: {export_model_path}")
    os.remove(export_model_path)
    print(f"Removed model: {export_model_path}")

# Load the model and move it to the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device != torch.device("cuda"):
    raise ValueError("GPU is not available")

if "NestedUNet" in load_model_path:
    model = NestedUNet(num_classes=1, input_channels=3, deep_supervision=False)
else:
    model = UNet(num_classes=1, input_channels=3, deep_supervision=False)
model.load_state_dict(torch.load(load_model_path, map_location=device))  # 加载权重时指定设备
model.to(device)
model.eval()

# Fixed random seed so the dummy input is reproducible
torch.manual_seed(42)
dummy_input = torch.randn(1, 3, 512, 512, device=device)  # create the input directly on the target device

with torch.no_grad():
    pytorch_output = model(dummy_input).cpu().numpy()

# Export to ONNX. The dynamic axes are kept as in the original run, even though
# they are not needed; a simplified fixed-shape export is sketched after the script.
torch.onnx.export(
    model,
    (dummy_input,),
    export_model_path,
    input_names=["input"],
    output_names=["output"],
    opset_version=12,  # opset 13 is recommended
    do_constant_folding=False,
    dynamic_axes={
        "input": {2: "height", 3: "width"},
        "output": {2: "height", 3: "width"},
    },
    verbose=False,
)

# Compare PyTorch and ONNX Runtime outputs (CPU/CUDA, graph optimizations disabled)
def run_and_diff(provider: str):
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    session = ort.InferenceSession(export_model_path, sess_options=so, providers=[provider])
    onnx_input = {session.get_inputs()[0].name: dummy_input.cpu().numpy()}
    onnx_output = session.run(None, onnx_input)[0]
    diff = np.abs(pytorch_output - onnx_output)
    print(f"[{provider}] max diff: {diff.max():.6f}")
    print(f"[{provider}] mean diff: {diff.mean():.6f}")

run_and_diff("CPUExecutionProvider")
if torch.cuda.is_available():
    try:
        run_and_diff("CUDAExecutionProvider")
    except Exception as e:
        print(f"CUDAExecutionProvider failed: {e}")

Terminal output
Model already exists: models\wafer_defect_dataset_0918_NestedUNet_woDS_0\model.onnx
Removed model: models\wafer_defect_dataset_0918_NestedUNet_woDS_0\model.onnx
[CPUExecutionProvider] max diff: 2.782166
[CPUExecutionProvider] mean diff: 0.457833
2025-10-16 19:34:30.1944079 [W:onnxruntime:, transformer_memcpy.cc:83 onnxruntime::MemcpyTransformer::ApplyImpl] 10 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.
[CUDAExecutionProvider] max diff: 2.393982
[CUDAExecutionProvider] mean diff: 0.202145

Can anyone help me, please? :sob: :sob:

Update: the problem is resolved. It was caused by the model weights: a normalization error during training left the BatchNorm layers with extremely small running variances. Since BatchNorm divides by sqrt(running_var + eps), a near-zero variance amplifies the tiny numerical differences between the PyTorch and ONNX Runtime kernels into the huge output differences above.
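
For anyone hitting the same symptom, here is a minimal diagnostic sketch (not part of the original script) that loads the checkpoint and prints the smallest BatchNorm running variance per layer; unusually small values point to this failure mode:

# Diagnostic: inspect BatchNorm running statistics in the trained checkpoint.
import torch
from torch import nn

from archs import NestedUNet

model = NestedUNet(num_classes=1, input_channels=3, deep_supervision=False)
state = torch.load("models/wafer_defect_dataset_0918_NestedUNet_woDS_0/model.pth",
                   map_location="cpu")
model.load_state_dict(state)
model.eval()

for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        # A healthy running_var is O(1); near-zero values blow up 1/sqrt(var + eps).
        print(f"{name}: min running_var = {module.running_var.min().item():.3e}")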