Convert PyTorch .pth model to ONNX with fixed width and variable height

I also posted the question here.

My code uses PyTorch to produce segmentation annotations for PNG images. The input images have a width of 512 pixels (or a multiple of it), but the height can range from 400 to 900 pixels. The code, along with the PyTorch model (*.pth file), works as expected.

I am currently attempting to convert my *.pth model to *.onnx. The code itself hasn’t changed much (only modifications related to ONNX, naturally), but the issue I am encountering is with the model conversion.

Here is my code for the model conversion:

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F

# pip install torch onnx


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64 * factor, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.up: nn.Upsample | nn.ConvTranspose2d

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels // 2, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(
            x1,
            [
                torch.div(diffX, 2, rounding_mode="floor"),
                diffX - torch.div(diffX, 2, rounding_mode="floor"),
                torch.div(diffY, 2, rounding_mode="floor"),
                diffY - torch.div(diffY, 2, rounding_mode="floor"),
            ],
        )
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


def convert_pytorch_to_onnx(pytorch_model_path, onnx_model_path):
    # Load the PyTorch model
    model = UNet(n_channels=1, n_classes=1)
    model.load_state_dict(torch.load(pytorch_model_path, map_location="cpu"))
    model.eval()

    # Create dummy input with dynamic size
    dummy_input = torch.randn(1, 1, 512, 512)  # Height [400 to 900], Width fixed at 512

    # Export the model
    torch.onnx.export(
        model,  # model being run
        dummy_input,  # model input (or a tuple for multiple inputs)
        onnx_model_path,  # where to save the model
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=20,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},  # variable length axes
            "output": {0: "batch_size", 2: "height", 3: "width"},
        },
    )

    # Verify the model
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

    print(f"Model {pytorch_model_path} converted to {onnx_model_path}")


# List of models to convert
models = [
    "models/case1.pth",
    "models/case2.pth",
]

# Convert each model
for model_path in models:
    onnx_path = model_path.replace(".pth", ".onnx")
    convert_pytorch_to_onnx(model_path, onnx_path)

Using the ONNX models created with:

dummy_input = torch.randn(1, 1, 512, 512)  # Height [400 to 900], Width fixed at 512

Inference on images with heights between 400 and 600 pixels gives results similar to my PyTorch code. However, images with a height of 800 pixels produce incorrect results, when they produce any output at all.
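
For context, this is roughly how I run the exported models (a simplified sketch with onnxruntime; the file names are placeholders and my real pre-processing is more involved):

import numpy as np
import onnxruntime as ort
from PIL import Image

# Load the exported model (placeholder path)
session = ort.InferenceSession("models/case1.onnx", providers=["CPUExecutionProvider"])

# Read a grayscale PNG and scale it to [0, 1]
img = Image.open("example.png").convert("L")
arr = np.asarray(img, dtype=np.float32) / 255.0

# Add batch and channel dimensions: (1, 1, H, W)
arr = arr[np.newaxis, np.newaxis, :, :]

# "input" and "output" are the names given at export time
(mask,) = session.run(["output"], {"input": arr})
print(mask.shape)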

Conversely, if I convert the model using:

dummy_input = torch.randn(1, 1, 885, 512)

An image with a height of 885 pixels works perfectly.

I’m not an expert in PyTorch or ONNX. For now, the only “workable” solution I’ve found is to use 512x512 in the dummy_input and add a few lines to my Python code to crop the top and bottom of input images if they are taller than 512 pixels. However, the results are not identical to those of the original PyTorch code.
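
The cropping workaround is essentially the following (a sketch; the helper name and the centre crop are just illustrative):

import numpy as np

def crop_to_height(arr: np.ndarray, target_height: int = 512) -> np.ndarray:
    """Crop equal amounts from top and bottom so the height matches target_height."""
    height = arr.shape[0]
    if height <= target_height:
        return arr
    top = (height - target_height) // 2
    return arr[top : top + target_height, :]

# Example: an 885x512 image becomes 512x512 before being fed to the ONNX model
image = np.zeros((885, 512), dtype=np.float32)
print(crop_to_height(image).shape)  # (512, 512)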

I’m unsure exactly what PyTorch does with variable-height inputs (I didn’t build the model), but from inspecting the UNet class it clearly does not crop the input. Instead, Up.forward pads the upsampled feature map with F.pad so that its size matches the corresponding skip connection, which is how heights that are not multiples of 16 still fit the model’s structure.
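
A quick way to see this is to feed the PyTorch model tensors of different heights and inspect the output shapes (a sketch using the UNet class above with random weights, which is enough to check shapes):

import torch

model = UNet(n_channels=1, n_classes=1)
model.eval()

with torch.no_grad():
    for height in (400, 512, 800, 885):
        x = torch.randn(1, 1, height, 512)
        y = model(x)
        # with this architecture the output should keep the input height and width
        print(height, tuple(y.shape))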

If that’s the case, how can I replicate this behaviour with an ONNX model?

I’m also wondering whether torch-onnx (github.com/justinchuby/torch-onnx) or the TorchDynamo-based exporter (pytorch.org/docs/stable/onnx.html) could help here.
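
If the dynamo-based exporter is the way to go, I assume the export would look roughly like this (untested sketch; torch.onnx.dynamo_export and ExportOptions(dynamic_shapes=True) are only available in reasonably recent PyTorch releases, and the output file name is a placeholder):

import torch

model = UNet(n_channels=1, n_classes=1)
model.load_state_dict(torch.load("models/case1.pth", map_location="cpu"))
model.eval()

dummy_input = torch.randn(1, 1, 512, 512)

# Ask the dynamo-based exporter to keep the input sizes symbolic
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(model, dummy_input, export_options=export_options)
onnx_program.save("models/case1_dynamo.onnx")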