ONNX export using torch gives wrong results

So I trained a ResNet18 on patched images. Running the .pth checkpoint on my test set gives good results; exporting it to ONNX gives way worse results.

I've checked the input tiles and they're identical: same max and min, same preprocessing. I even exported the raw tiles and compared them, including the individual pixel values, and those were also the same.

Yet when I run the image through the ONNX model, the exported logits are different.

(Same result on both Windows and Ubuntu.)

All the needed files should be included below. If something doesn't work, be sure to let me know; I trimmed things a bit so I wouldn't have to upload everything.

To reproduce

First run the export to ONNX as follows:
uv run python export_to_onnx.py --model "resnet18" --weights "./resnet18_best.pth" --output "test.onnx"

Then prepare the test images: the root folder should contain two subfolders, 1.Knot and 0.NoKnot. You can just copy the same image twice.

uv run python test_onnx.py --onnx "./test.onnx" --root ./images

I also added two images. The first one has no knots, and its logits should be
[-3.072449 -2.8448505 -7.6250186 -6.501796 ] → this is what the .pth file produces, but not the .onnx file.

The second one has knots and should return three positive classes (1) and one negative (0).
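For reference, the expected per-tile predictions can be derived from the logits with a plain sigmoid and a 0.5 threshold (a minimal sketch, assuming the binary head used in the scripts below; all four logits of the no-knot image are negative, so every tile should predict 0):

```python
import numpy as np

def tile_predictions(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Convert raw per-tile logits to binary predictions via sigmoid."""
    probs = 1.0 / (1.0 + np.exp(-logits))        # sigmoid
    return (probs > threshold).astype(np.int64)  # 1 = Knot, 0 = NoKnot

# Expected logits for the no-knot image (from the .pth model):
logits_pth = np.array([-3.072449, -2.8448505, -7.6250186, -6.501796], dtype=np.float32)
print(tile_predictions(logits_pth))  # all logits negative -> [0 0 0 0]
```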

Urgency

It is quite urgent, as I need this working within the week.

Platform

Windows

OS Version

Windows 11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

onnxruntime 1.23.2

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

CUDA 12.8 (Windows) and 12.6 (Ubuntu)

Model File

If needed i can try to upload it.

Is this a quantized model?

No

I also have a ticket open on the onnxruntime GitHub:
[Performance] Outputs are completely different between .onnx model and .pth model · Issue #26504 · microsoft/onnxruntime

New additions:

From the export report:

No errors.

Besides this, I also tried the latest torch nightly and onnxscript:

torch 2.10.0.dev20251105+cu126
torchvision 0.25.0.dev20251105+cu126
onnxscript 0.5.6

But the issue persists.

There is, however, one small difference between the tensors sent to the ONNX and .pth models: the .pth version is rounded off a bit earlier in the printout. I don't know if this is a big issue or not.

PTH:
tensor([[[[-1.9124, -1.8953, -1.8953,  ..., -1.6213, -1.6213, -1.6555],
          [-1.9124, -1.8953, -1.8953,  ..., -1.6555, -1.6555, -1.6384],
          [-1.9124, -1.8953, -1.8953,  ..., -1.6555, -1.6555, -1.6384],
...

ONNX:
[[[[-1.9124069 -1.8952821 -1.8952821 ... -1.6212862 -1.6212862 -1.6555357]
   [-1.9124069 -1.8952821 -1.8952821 ... -1.6555357 -1.6555357 -1.6384109]
   [-1.9124069 -1.8952821 -1.8952821 ... -1.6555357 -1.6555357 -1.6384109]
...
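The rounding difference above is most likely just display formatting: torch prints tensors rounded to 4 decimals by default, while numpy's repr shows more digits, even when the underlying float32 values are identical. A quick numpy sketch (values copied from the first row above):

```python
import numpy as np

# One row of each dump; both are the same float32 values.
pth_row  = np.array([-1.9124069, -1.8952821, -1.8952821], dtype=np.float32)
onnx_row = np.array([-1.9124069, -1.8952821, -1.8952821], dtype=np.float32)

# torch-style display rounds to 4 decimals; numpy's default shows more digits.
print(np.array2string(pth_row, precision=4))   # [-1.9124 -1.8953 -1.8953]
print(np.array2string(onnx_row, precision=8))  # full precision

# The real test is equality of the stored values, not of the printed strings:
assert np.array_equal(pth_row, onnx_row)
```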

I also tried using the exact same preprocessing on both sides, so I changed my ONNX wrapper to do only the following:

def forward(self, tiles):
    logits = self.model(tiles)
    return logits, logits

Then I ran it right before my .pth model inside the test loop:

with torch.no_grad():
    for patches, label, path in test_loader:
        sess = create_session(onnx_path)
        in_name = sess.get_inputs()[0].name
        outputs = sess.run(None, {in_name: patches.cpu().numpy()})
        print(f"ONNX outputs: {outputs}")
        logits = model(patches) 

And for the same input there are two different outputs:

ONNX outputs: [array([[-2.9795594],
[-3.2486563],
[-2.2178779],
[-2.025703 ]], dtype=float32)

Logits per patch: [-3.072449 -2.8448505 -7.6250186 -6.501796 ] (PTH)
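Putting numbers on that gap (values copied from above): a CPU-vs-GPU numerical difference is typically on the order of 1e-5, not whole units, which suggests this is more than device noise. A minimal check:

```python
import numpy as np

onnx_logits = np.array([-2.9795594, -3.2486563, -2.2178779, -2.025703], dtype=np.float32)
pth_logits  = np.array([-3.072449, -2.8448505, -7.6250186, -6.501796], dtype=np.float32)

diff = np.abs(onnx_logits - pth_logits)
print("per-patch |diff|:", diff)
print("max |diff|:", diff.max())  # > 4.0, far beyond float32 noise

# A healthy export should pass a tolerance check like this one (this fails here):
assert not np.allclose(onnx_logits, pth_logits, rtol=1e-3, atol=1e-4)
```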

Could the difference come from the .pth model running on GPU while the ONNX model runs on CPU?

Update:

I exported another, older checkpoint and it also gives different results from the earlier ONNX version (same code). This makes me think there might be a version mismatch somewhere. I'll look further into it in case it's something else.

(non-patched version)

Edit: It is almost as if the models themselves differ, even though the exact same .pth file was used. Or the weights/biases are incorrectly parsed/bound into the ONNX model.

Code:

config.py

config = {
    "model_name": "resnet18",  # change to "mobilenet" or "resnet18" or "resnet50"
    # Standard configuration
    "image_width": 1224,
    "image_height": 1024,
    "batch_size": 16,
    "train_dir": "./data/train",
    "val_dir": "./data/val",
    "test_dir": "./data/test",
    # "test_dir": "./data",
    "model_save_dir_path": "./artifacts/",
    # training parameters
    "num_epochs": 150,
    "learning_rate": 0.0001,
    "reduce_lr_on_plateau": {
        "factor": 0.5,
        "patience": 2,
        "min_lr": 1e-6,
        "threshold": 1e-2 # lower (e.g. 1e-5) for stable learning
    },
    "early_stop_patience": 15,
    "freeze_epochs": 3  # number of epochs to keep backbone frozen
}

Original export_to_onnx.py:

"""Export a quadrant-patching model to ONNX format."""

import argparse

import torch
from torch import nn
from torchvision import models
from torchvision.models import ResNet18_Weights

def load_model(image_width=64, image_height=64):
    model = models.resnet18(weights=ResNet18_Weights.DEFAULT)

    # Replace 3-channel conv1 with 1-channel
    # model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    # Replace the classifier for binary output
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 1)
    )

    return model

# export_model_to_onnx_quadrants.py  (only the wrapper shown)

class ONNXWrapper(nn.Module):
    """
    NHWC uint8 full image(s) -> TL/TR/BL/BR tiles -> resize to (target_h, target_w)
    -> normalize (ImageNet) -> model -> per-tile probs [4B,2] & preds [4B,1].
    """
    def __init__(self, model, target_h: int, target_w: int):
        super().__init__()
        self.model = model
        self.register_buffer("mean_imagenet", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std_imagenet",  torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.target_h = target_h
        self.target_w = target_w

    def forward(self, image):
        """
        Splits input image into quadrants, preprocesses, and runs through the model.
        Args:
            image (torch.Tensor): Input image tensor in NHWC uint8 format.
        Returns:
            tuple: Preprocessed tiles and predictions.
        """
        x = image
        tiles = []
        tile_h, tile_w = 1024, 1224  # hardcoded for quadrant tiling
        for i in range(2):
            for j in range(2):
                start_h = i * tile_h
                end_h = (i + 1) * tile_h
                start_w = j * tile_w
                end_w = (j + 1) * tile_w
                tile = x[:, start_h:end_h, start_w:end_w, :] # [B, tile_h, tile_w, 3]
                tiles.append(tile)

        tiles = torch.cat(tiles, dim=0)  # (4, tile_h, tile_w, 3)
        tiles = tiles.permute(0, 3, 1, 2).float()  # (4, 3, tile_h, tile_w)
        tiles = tiles / 255.0
        tiles = (tiles - self.mean_imagenet) / self.std_imagenet

        logits = self.model(tiles)  # [4, 1] or [4, 2]
        if logits.ndim == 1:
            logits = logits.unsqueeze(1)
        if logits.ndim == 2 and logits.shape[1] == 1:
            conf = torch.sigmoid(logits)
            preds = (conf > 0.5).to(logits.dtype)
        else:
            conf = torch.sigmoid(logits)
            preds = torch.argmax(conf, dim=1)
        return logits, preds # export logits to inspect raw scores

def export_to_onnx(model_name, weights_path, output_path, patch_w, patch_h,
                   dummy_full_w=2448, dummy_full_h=2048):
    """
    Exports the model to ONNX format using quadrant tiling.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_model()
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    wrapper = ONNXWrapper(model, target_h=patch_h, target_w=patch_w)

    dummy = torch.zeros(1, dummy_full_h, dummy_full_w, 3, dtype=torch.uint8)

    torch.onnx.export(
        wrapper,
        dummy,
        output_path,
        input_names=["image"],
        output_names=["tiles", "predictions"],
    )
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Export quadrant-patching model to ONNX")
    ap.add_argument("--model", required=True, help="Model name (e.g. resnet18)")
    ap.add_argument("--weights", required=True, help="Path to .pth")
    ap.add_argument("--output", required=True, help="Path to .onnx")
    ap.add_argument("--width",  type=int, default=1224, help="Trained PATCH width")
    ap.add_argument("--height", type=int, default=1024, help="Trained PATCH height")
    # You can leave dummy size at 2448x2048; dynamic axes make runtime flexible
    ap.add_argument("--dummy_full_width",  type=int, default=2448)
    ap.add_argument("--dummy_full_height", type=int, default=2048)
    args = ap.parse_args()

    export_to_onnx(
        args.model, args.weights, args.output,
        patch_w=args.width, patch_h=args.height,
        dummy_full_w=args.dummy_full_width, dummy_full_h=args.dummy_full_height
    )

test_onnx.py:

# test_onnx_quadrants.py
"""
ONNX test script for RGB image classifier with 4-quadrant patching inside the ONNX graph.

Assumes the ONNX model:
- Accepts NHWC uint8 full images of arbitrary H×W
- Splits into TL/TR/BL/BR, resizes each tile to trained patch size
- Outputs:
    class_probabilities: [4B, 2]  (per-tile probs: [NoKnot, Knot])
    class_predictions:   [4B, 1]  (per-tile binary after sigmoid>0.5)

Image-level decision = 1 if ANY tile's class-1 probability >= threshold, else 0.
"""

import argparse
import os
from collections import Counter

import numpy as np
from PIL import Image
import onnxruntime as ort
from torchvision.datasets import ImageFolder
from PIL import ImageDraw
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score


PATCH_NAMES = ["TL", "TR", "BL", "BR"]


def create_session(onnx_path: str, providers=None) -> ort.InferenceSession:
    if providers is None:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    try:
        return ort.InferenceSession(onnx_path, providers=providers)
    except Exception:
        return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])


def load_image_nhwc_uint8(path: str, full_size: tuple[int, int]) -> np.ndarray:
    """
    Open image, convert to RGB, return NHWC uint8 with batch dim.
    Images are assumed to already match full_size (W, H); no resize is applied.
    """
    W, H = full_size  # kept for the size contract; no resize happens here
    img = Image.open(path).convert("RGB")
    arr = np.asarray(img, dtype=np.uint8)         # [H, W, 3]
    arr = np.expand_dims(arr, axis=0)             # [1, H, W, 3]
    return arr


def infer_image(sess: ort.InferenceSession, x_nhwc_uint8: np.ndarray, threshold: float):
    """
    Run one image through ONNX. Returns (overall_pred, pos_scores, fired_idxs).
    pos_scores are per-tile class-1 probabilities (length 4).
    """
    print(x_nhwc_uint8.shape)
    in_name = sess.get_inputs()[0].name
    outputs = sess.run(None, {in_name: x_nhwc_uint8})
    # Map names -> arrays
    out_names = [o.name for o in sess.get_outputs()]
    name2 = {o.name: arr for o, arr in zip(sess.get_outputs(), outputs)}

    if "class_probabilities" in out_names:
        probs = name2["class_probabilities"]  # [4B, 2]; B=1 -> [4, 2]
    else:
        # Fallback: assume first output are probs/logits with shape [4,2].
        probs = outputs[0]
        if not (probs.ndim == 2 and probs.shape[1] == 2):
            raise RuntimeError(f"Unexpected output shape {probs.shape}; expected [4,2].")

    # Per-tile positive (class-1) probabilities
    pos_scores = probs[:, 1]  # shape [4]
    fired_mask = pos_scores >= threshold
    overall_pred = int(np.any(fired_mask))
    fired_idxs = np.where(fired_mask)[0].tolist()
    return overall_pred, pos_scores, fired_idxs


def main():
    ap = argparse.ArgumentParser(description="Test ONNX model with quadrant tiling inside the graph")
    ap.add_argument("--onnx", required=True, help="Path to model.onnx")
    ap.add_argument("--root", required=True, help="Path to test ImageFolder root")
    ap.add_argument("--full_width", type=int, default=2448, help="Full image width fed to ONNX")
    ap.add_argument("--full_height", type=int, default=2048, help="Full image height fed to ONNX")
    ap.add_argument("--threshold", type=float, default=0.5, help="Tile-level positive threshold")
    args = ap.parse_args()

    # Build dataset (labels & paths from folder names)
    ds = ImageFolder(args.root)
    classes = ds.classes
    samples = ds.samples  # list of (path, label)

    print(f"Discovered {len(samples)} image(s) across classes: {classes}")
    if len(samples) == 0:
        print("No images found; exiting.")
        return

    # Create session
    sess = create_session(args.onnx)
    in_name = sess.get_inputs()[0].name
    print(f"ONNX input: {in_name} | providers: {sess.get_providers()}")

    y_true, y_pred = [], []
    wrong_details = []  # tuples: (path, true, pred, fired_names, pos_scores)

    # Evaluate
    for path, label in samples:
        x = load_image_nhwc_uint8(path, (args.full_width, args.full_height))
        try:
            pred, pos_scores, fired_idxs = infer_image(sess, x, args.threshold)
        except Exception as e:
            print(f"[ERROR] {path}: {e}")
            continue
        

        y_true.append(int(label))
        y_pred.append(int(pred))

        if pred != int(label):
            fired_names = [PATCH_NAMES[i] for i in fired_idxs]
            wrong_details.append((path, int(label), int(pred), fired_names, pos_scores.tolist()))
        # break

    # Reporting
    print("\n🧪 Test Results")
    print(classification_report(y_true, y_pred, target_names=classes, zero_division=0))

    print("🧮 Confusion Matrix")
    print(confusion_matrix(y_true, y_pred, labels=[0, 1]))

    print("\nClass balance (y_true):", Counter(y_true))
    print("Class balance (y_pred):", Counter(y_pred))

    if wrong_details:
        print("\nMisclassified images (tile probs shown as [TL, TR, BL, BR]):")
        for (path, y, p, fired, pos_scores) in wrong_details:
            # reorder pos_scores to TL,TR,BL,BR if needed (we assume wrapper order is TL,TR,BL,BR)
            tile_str = ", ".join(f"{name}={pos_scores[i]:.3f}" for i, name in enumerate(PATCH_NAMES))
            fired_str = ", ".join(fired) if fired else "none"
            print(f" - {path} | true={y} pred={p} | fired: [{fired_str}] | {tile_str}")
    else:
        print("\nNo misclassified images 🎉")

    # Return summary metrics if you want to call this as a function elsewhere
    acc = float(np.mean(np.array(y_true) == np.array(y_pred))) if y_true else float("nan")
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec  = recall_score(y_true, y_pred, zero_division=0)
    f1   = f1_score(y_true, y_pred, zero_division=0)
    print(f"\nSummary: acc={acc:.4f} precision={prec:.4f} recall={rec:.4f} f1={f1:.4f}")


if __name__ == "__main__":
    main()

test_pth.py

"""
Test script for RGB image classifier with 4-quadrant patching per image.

Each image is split into: top-left (TL), top-right (TR), bottom-left (BL), bottom-right (BR).
Image-level prediction = 1 if any quadrant predicts 1, else 0.

Args:
    --model_path: Path to the trained model .pth file
Returns:
    Prints classification report and confusion matrix.
"""
import argparse
import numpy as np
from collections import Counter

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, confusion_matrix
from torchvision.transforms import InterpolationMode
from config import config

from torchvision import models
from torchvision.models import ResNet18_Weights
import torch.nn as nn

def load_model(image_width=64, image_height=64):
    model = models.resnet18(weights=ResNet18_Weights.DEFAULT)

    # Replace 3-channel conv1 with 1-channel
    # model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    # Replace the classifier for binary output
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 1)
    )

    return model

PATCH_NAMES = ["TL", "TR", "BL", "BR"]


def get_test_transforms(width, height):
    """Transforms applied per PATCH (not on the whole image)."""
    return transforms.Compose([
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.Resize((height, width), interpolation=InterpolationMode.BILINEAR, antialias=True),  # resize each patch to model input
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])


class QuadrantPatchDataset(Dataset):
    """Wraps ImageFolder; for each image returns 4 patches (TL, TR, BL, BR)."""
    def __init__(self, root, transform=None):
        self.base = ImageFolder(root)
        self.transform = transform
        # Keep class names and image paths accessible
        self.classes = self.base.classes
        self.class_to_idx = self.base.class_to_idx
        self.samples = self.base.samples
        self.image_paths = [p for p, _ in self.base.samples]

    def __len__(self):
        return len(self.base)

    def _quadrants(self, img):
        # img is a PIL Image
        W, H = img.size
        xm, ym = W // 2, H // 2
        # Boxes: (left, upper, right, lower)
        print(f"Image size: W={W}, H={H}, xm={xm}, ym={ym}")
        boxes = [
            (0,   0,   xm, ym),    # TL
            (xm,  0,   W,  ym),    # TR
            (0,   ym,  xm, H),     # BL
            (xm,  ym,  W,  H),     # BR
        ]
        
        return [img.crop(b) for b in boxes]

    def __getitem__(self, idx):
        path, label = self.base.samples[idx]
        img = self.base.loader(path)  # PIL
        quads = self._quadrants(img)
        if self.transform is not None:
            quads = [self.transform(q) for q in quads]  # list of Tensors [C,H,W]
        # Stack to [4, C, H, W]
        patches = torch.stack(quads, dim=0)
        return patches, label, path


def test_model(config, model_path: str):
    """Test RGB image classifier with quadrant patching and print metrics."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset & loader
    transform = get_test_transforms(config["image_width"], config["image_height"])
    test_set = QuadrantPatchDataset(config["test_dir"], transform=transform)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Load model
    model = load_model()
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.eval().to(device)

    # Evaluation
    y_true, y_pred = [], []
    wrong_images = []
    wrong_details = []  # store which patches fired

    with torch.no_grad():
        for patches, label, path in test_loader:
            # patches: [1, 4, C, H, W] -> [4, C, H, W]
            patches = patches.squeeze(0).to(device)
            # Model forward on the 4 patches as a batch of size 4
            logits = model(patches)  # expect [4, 1] or [4] for binary BCEWithLogitsLoss setup
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(1)  # [4]
            # Convert logits to probabilities for thresholding
            probs = torch.sigmoid(logits)  # [4]
            preds_patch = (probs > 0.5).int().cpu().numpy()  # [4] in {0,1}

            # Aggregate to image-level prediction: 1 if any patch == 1
            img_pred = int(np.any(preds_patch == 1))
            y_pred.append(img_pred)
            y_true.append(int(label))

            if img_pred != int(label):
                # Which patches fired?
                fired = [PATCH_NAMES[i] for i, v in enumerate(preds_patch.tolist()) if v == 1]
                wrong_images.append(path)
                wrong_details.append((path, fired))

    # Reporting
    print("🧪 Test Results")
    print(classification_report(y_true, y_pred, target_names=test_set.classes, zero_division=0))

    print("🧮 Confusion Matrix")
    print(confusion_matrix(y_true, y_pred, labels=[0, 1]))

    print("\nClass balance (y_true):", Counter(y_true))
    print("Class balance (y_pred):", Counter(y_pred))

    if wrong_images:
        print("\nMisclassified images:")
        for (path, fired) in wrong_details:
            fired_str = ", ".join(fired) if fired else "none"
            print(f" - {path} | patches predicted 1: [{fired_str}]")
    else:
        print("\nNo misclassified images 🎉")

    return {
        "accuracy": float(np.mean(np.array(y_true) == np.array(y_pred))) if y_true else float("nan"),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0)
    }


if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Test RGB Classifier Model (Quadrant Patching)")
    ap.add_argument("--model_path", required=True, help="Path to the trained model .pth file")
    args = ap.parse_args()
    test_model(config, model_path=args.model_path)

This seemed to fix it:

Exporting under torch.inference_mode() with dynamo=False, training=torch.onnx.TrainingMode.EVAL, do_constant_folding=True, and keep_initializers_as_inputs=False:

"""Export a quadrant-patching model to ONNX format."""

import argparse
import os
import sys
import torch
from torch import nn

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from models import load_model  # your existing loader

# export_model_to_onnx_quadrants.py  (only the wrapper shown)

class ONNXWrapper(nn.Module):
    """
    NHWC uint8 full image(s) -> TL/TR/BL/BR tiles -> resize to (target_h, target_w)
    -> normalize (ImageNet) -> model -> per-tile probs [4B,2] & preds [4B,1].
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.register_buffer("mean_imagenet", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std_imagenet",  torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """
        Splits input image into quadrants, preprocesses, and runs through the model.
        Args:
            image (torch.Tensor): Input image tensor in NHWC uint8 format.
        Returns:
            tuple: Preprocessed tiles and predictions.
        """

        tiles = []
        tile_h, tile_w = 1024, 1224  # hardcoded for quadrant tiling
        for i in range(2):
            for j in range(2):
                start_h = i * tile_h
                end_h = (i + 1) * tile_h
                start_w = j * tile_w
                end_w = (j + 1) * tile_w
                tile = x[:, start_h:end_h, start_w:end_w, :] # [B, tile_h, tile_w, 3]
                tiles.append(tile)

        tiles = torch.cat(tiles, dim=0)  # (4, tile_h, tile_w, 3)
        tiles = tiles.permute(0, 3, 1, 2).float()  # (4, 3, tile_h, tile_w)
        tiles = tiles / 255.0
        tiles = (tiles - self.mean_imagenet) / self.std_imagenet

        logits = self.model(tiles)  # [4, 1] or [4, 2]
        if logits.ndim == 1:
            logits = logits.unsqueeze(1)
        if logits.ndim == 2 and logits.shape[1] == 1:
            conf = torch.sigmoid(logits)
            preds = (conf > 0.5).to(logits.dtype)
        else:
            conf = torch.sigmoid(logits)
            preds = torch.argmax(conf, dim=1)
        return conf, preds 

def export_to_onnx(model_name, weights_path, output_path,
                   dummy_full_w=2448, dummy_full_h=2048):
    """
    Exports the model to ONNX format using quadrant tiling.
    """
    device = torch.device("cpu")
    print(f"using device: {device}")

    model = load_model(model_name, 1224, 1024)
    model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)  # note: strict=False silently ignores missing/unexpected keys
    model.eval().cpu()

    wrapper = ONNXWrapper(model)

    dummy = torch.zeros(1, 2048, 2448, 3, dtype=torch.uint8)

    # torch.onnx.export(
    #     wrapper,
    #     dummy,
    #     output_path,
    #     input_names=["image"],
    #     output_names=["probabilities", "predictions"],
    #     dynamo=True,
    #     report=True,
    #     export_params=True,
    #     opset_version=18,
    #     do_constant_folding=True
    # )
    model.eval()
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,                     
            dummy,                        
            output_path,
            opset_version=18,
            dynamo=False,                
            training=torch.onnx.TrainingMode.EVAL,
            do_constant_folding=True,
            keep_initializers_as_inputs=False,
            input_names=["input"],
            output_names=["probabilities", "predictions"],
        )

if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Export quadrant-patching model to ONNX")
    ap.add_argument("--model", required=True, help="Model name (e.g. resnet18)")
    ap.add_argument("--weights", required=True, help="Path to .pth")
    ap.add_argument("--output", required=True, help="Path to .onnx")
    ap.add_argument("--width",  type=int, default=1224, help="Trained PATCH width")
    ap.add_argument("--height", type=int, default=1024, help="Trained PATCH height")
    # You can leave dummy size at 2448x2048; dynamic axes make runtime flexible
    ap.add_argument("--dummy_full_width",  type=int, default=2448)
    ap.add_argument("--dummy_full_height", type=int, default=2048)
    args = ap.parse_args()

    export_to_onnx(
        args.model, args.weights, args.output
    )