I ran into an issue when exporting torch.nn.functional.interpolate to ONNX.
I want the input image to have a dynamic shape, with the output rescaled so that its maximum dimension is always 1024 pixels while the height/width aspect ratio is preserved (for example, a 2048x2000 input should come out as 1024x1000).
However, when exporting the model to ONNX, I have to provide a dummy input with a specific size (for example 512x384 pixels). I observe that the scale factor of the “Resize” operator in the ONNX graph is always computed from the shape of the dummy tensor, which causes a difference between inference with ONNX and with PyTorch. What I expect is for the scale factor of “Resize” to be computed from the shape of the actual input tensor, but it is not.
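As a minimal illustration of what I think is happening (my own sketch, separate from the repro script below; resize_to_max is just a hypothetical helper): during tracing, x.shape[2:] yields plain Python ints, so the computed scale is evaluated once and recorded in the traced graph as a constant instead of as a function of the input.

import torch

def resize_to_max(x):
    h, w = x.shape[2:]               # plain Python ints during tracing
    scale = float(1024 / max(h, w))  # evaluated once, recorded as a constant
    return torch.nn.functional.interpolate(
        x, scale_factor=scale, mode="bilinear", align_corners=False
    )

traced = torch.jit.trace(resize_to_max, torch.randn(1, 3, 512, 384))
print(traced.graph)  # the scale appears as a constant, not as a computation from x's shape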
I attach the code to reproduce the issue, as well as a screen capture of the ONNX model.
Do you have any idea how to make the scale factor of the Resize operator a function of the input shape?
Thank you!
from pathlib import Path
import warnings

import torch
import torch.onnx
import onnxruntime as ort


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.max_size = 1024

    def forward(self, x):
        h, w = x.shape[2:]
        # scale so that the larger of the two spatial dimensions becomes max_size
        scale = float(self.max_size / max(h, w))
        return torch.nn.functional.interpolate(
            x, scale_factor=scale, align_corners=False, mode="bilinear"
        )
def export_dummy_model(out_path: Path):
    # create dummy model
    model = DummyModel()
    # set dynamic axes
    dynamic_axes = {
        "input": {2: "H", 3: "W"},
        "output": {2: "H", 3: "W"},
    }
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(str(out_path), "wb") as fp:
            torch.onnx.export(
                model,
                torch.randn(1, 3, 512, 384, dtype=torch.float),
                fp,
                export_params=True,
                verbose=False,
                opset_version=17,
                do_constant_folding=True,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes=dynamic_axes,
            )
if __name__ == "__main__":
    # export dummy onnx
    onnx_path = Path(__file__).parent / "dummy.onnx"
    export_dummy_model(onnx_path)

    # test with the onnx model and the pytorch model
    dummy_input = torch.randn(1, 3, 2048, 2000, dtype=torch.float)

    # onnxruntime
    ort_sess = ort.InferenceSession(str(onnx_path))
    output_onnx = ort_sess.run(None, {"input": dummy_input.numpy()})
    print(output_onnx[0].shape)  # (1, 3, 4096, 4000) => scale 2.0 baked in from the 512x384 dummy

    # pytorch
    dummy_model = DummyModel()
    output_pytorch = dummy_model(dummy_input)
    print(output_pytorch.shape)  # (1, 3, 1024, 1000) => it's what I want
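A quick way to confirm that the scale is baked into the exported file is to inspect the graph directly (a sketch, assuming the onnx package is installed; depending on the exporter version the scales may sit in an initializer or come in via a Constant node instead):

import onnx
from onnx import numpy_helper

model = onnx.load("dummy.onnx")
# map initializer names to their constant values
constants = {t.name: numpy_helper.to_array(t) for t in model.graph.initializer}
for node in model.graph.node:
    if node.op_type == "Resize":
        print(node.op_type, node.input)
        for name in node.input:
            if name in constants:
                # expected: a fixed scale of 2.0 (= 1024/512) for H and W,
                # independent of the actual input shape
                print("  constant input:", name, constants[name])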