Export to ONNX: flexible scale factor for the Resize operator

I ran into an issue when exporting torch.nn.functional.interpolate to ONNX.
I want the model to accept an input image with a dynamic shape and rescale it so that the larger dimension of the output is always 1024 pixels, while preserving the height/width ratio of the image.
However, when exporting the model to ONNX, I have to provide a dummy input with a specific size (for example 512x384 pixels). I observe that the scale factor of the “Resize” operator in the ONNX graph is always computed from the shape of that dummy tensor, which causes a difference between inference with ONNX and with PyTorch. I expected the scale factor of “Resize” to be computed from the shape of the actual input tensor, but it is not.
I attach the code to reproduce the issue, as well as a screen capture of the ONNX model.
Do you have any idea how to make the scale factor of the Resize operator a function of the input shape?
Thank you!

from pathlib import Path
import warnings
import torch
import torch.onnx
import onnxruntime as ort

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.max_size = 1024

    def forward(self, x):
        h, w = x.shape[2:]
        scale = float(self.max_size / max(h, w))
        return torch.nn.functional.interpolate(x, scale_factor=scale, align_corners=False, mode="bilinear")


def export_dummy_model(out_path: Path):
    # create dummy model
    model = DummyModel()
    # set dynamic axes
    dynamic_axes = {
        "input": {2: "H", 3: "W"},
        # the output size differs from the input size, so give its axes their own names
        "output": {2: "H_out", 3: "W_out"},
    }

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(str(out_path), "wb") as fp:
            torch.onnx.export(
                model,
                torch.randn(1, 3, 512, 384, dtype=torch.float),
                fp,
                export_params=True,
                verbose=False,
                opset_version=17,
                do_constant_folding=True,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes=dynamic_axes,
            )


if __name__ == "__main__":
    # export dummy onnx
    onnx_path = Path(__file__).parent / "dummy.onnx"
    export_dummy_model(onnx_path)

    # Test with onnx model and pytorch model
    dummy_input = torch.randn(1, 3, 2048, 2000, dtype=torch.float)
    # onnxruntime
    ort_sess = ort.InferenceSession(str(onnx_path))  # pass a plain string path
    output_onnx = ort_sess.run(None, {'input': dummy_input.numpy()})
    print(output_onnx[0].shape)  # (1, 3, 4096, 4000): scale 2.0 frozen from the 512x384 export input
    # pytorch
    dummy_model = DummyModel()
    output_pytorch = dummy_model(dummy_input)
    print(output_pytorch.shape)  # (1, 3, 1024, 1000), which is what I want
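
The two shapes printed above are consistent with a scale factor frozen at export time. A minimal arithmetic sketch, assuming the resize output is floor(size * scale) per spatial dimension (the helper name resize_output_hw is only illustrative, not part of any library):

```python
import math


def resize_output_hw(h, w, scale):
    # Assumed rule: each spatial dimension becomes floor(size * scale).
    return math.floor(h * scale), math.floor(w * scale)


MAX_SIZE = 1024
export_h, export_w = 512, 384   # dummy input used at export time
infer_h, infer_w = 2048, 2000   # input used at inference time

# Scale computed once from the dummy input and then reused for every input.
frozen_scale = MAX_SIZE / max(export_h, export_w)
# Scale recomputed from the actual input, which is the desired behaviour.
dynamic_scale = MAX_SIZE / max(infer_h, infer_w)

print(frozen_scale, resize_output_hw(infer_h, infer_w, frozen_scale))    # 2.0 (4096, 4000)
print(dynamic_scale, resize_output_hw(infer_h, infer_w, dynamic_scale))  # 0.5 (1024, 1000)
```

The first line reproduces the ONNX result and the second reproduces the PyTorch result, which suggests the exported graph holds the constant 2.0 rather than an expression of the input shape.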

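I found the solution in the documentation: pass a ScriptModule (created with torch.jit.script) instead of a plain nn.Module when exporting to ONNX. As I understand it, a plain nn.Module is traced during export, so the Python float computed from the dummy input's shape is recorded as a constant, whereas a scripted module keeps the shape computation in the graph: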

torch.onnx.export(
    torch.jit.script(model),
    dummy_input,
    fp,
    export_params=True,
    verbose=False,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=dynamic_axes,
)
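
The difference between tracing and scripting can also be checked in PyTorch alone, without going through ONNX. The sketch below is only an illustration: TinyResize is a hypothetical, smaller copy of DummyModel (max_size=64 instead of 1024, to keep it fast), and torch.jit.trace stands in for the tracing that the export performs on a plain nn.Module.

```python
import warnings

import torch
import torch.nn.functional as F


class TinyResize(torch.nn.Module):
    """Hypothetical smaller stand-in for DummyModel (max_size=64)."""

    def __init__(self):
        super().__init__()
        self.max_size = 64

    def forward(self, x):
        h, w = x.shape[2:]
        scale = float(self.max_size / max(h, w))
        return F.interpolate(x, scale_factor=scale, align_corners=False, mode="bilinear")


example = torch.randn(1, 3, 32, 24)   # shape seen while tracing
other = torch.randn(1, 3, 128, 100)   # a different shape at inference time

with warnings.catch_warnings():
    # Tracing warns that converting a tensor to a Python number freezes its value.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(TinyResize(), example)

# Scripting compiles forward() itself, so the scale stays a function of x.shape.
scripted = torch.jit.script(TinyResize())

eager_shape = tuple(TinyResize()(other).shape)
traced_shape = tuple(traced(other).shape)
scripted_shape = tuple(scripted(other).shape)
print("eager:   ", eager_shape)
print("traced:  ", traced_shape)
print("scripted:", scripted_shape)
```

On the second input, the scripted module should match eager mode (largest side 64), while the traced module should keep applying the scale 2.0 computed from the 32x24 example. This mirrors the mismatch between ONNX and PyTorch described in the question.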