PyTorch ONNX JavaScript model with ReflectionPad and ConvTranspose

I’d like to export a pretrained model to ONNX format so that I can run it in a browser with JavaScript (ONNX.js). The model uses ReflectionPad and ConvTranspose. If I export with opset version <= 10, ONNX.js complains that ConvTranspose is not implemented. If I export with opset version >= 11, it complains that there are int64 values in my model, which it can’t handle. My model doesn’t contain any, but the ReflectionPad export seems to create them.

The model definition is as follows:

import torch
import torch.nn as nn
import onnx

print("pytorch version :", torch.__version__)
print("onnx version    :", onnx.__version__)

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        
        self.reflectionpad = nn.ReflectionPad1d(3)
        self.conv = nn.ConvTranspose1d(4, 4, kernel_size=3)
        
    def forward(self, x):
        x = self.reflectionpad(x)
        return self.conv(x)

model_g = MyModule()
x = torch.randn(1, 4, 4)
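To make the shapes in the exported graphs below easier to follow, here is a pure-Python sketch of what nn.ReflectionPad1d(3) does to a single channel of length 4. The function name `reflection_pad_1d` is hypothetical, not a PyTorch API. The sequence is mirrored at both ends without repeating the edge sample, so the length grows from 4 to 10.

```python
def reflection_pad_1d(seq, pad_left, pad_right):
    """Mirror `seq` at both ends without repeating the edge sample.

    Hypothetical helper illustrating ReflectionPad1d / ONNX Pad(mode="reflect").
    """
    n = len(seq)
    if pad_left >= n or pad_right >= n:
        # reflection padding needs at least pad + 1 samples to mirror from
        raise ValueError("reflection padding must be smaller than the input length")
    left = [seq[i] for i in range(pad_left, 0, -1)]   # seq[pad_left], ..., seq[1]
    right = [seq[n - 2 - i] for i in range(pad_right)]  # seq[n-2], seq[n-3], ...
    return left + list(seq) + right


padded = reflection_pad_1d([0, 1, 2, 3], 3, 3)
print(padded)  # [3, 2, 1, 0, 1, 2, 3, 2, 1, 0]
```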

The bug_example.js script:

async function runExample() {
    // Create an ONNX inference session with default backend.
    const session = new onnx.InferenceSession();

    await session.loadModel("toy_model.onnx");
    const x = new Float32Array(1 * 4 * 4).fill(1);
    const tensorX = new onnx.Tensor(x, 'float32', [1, 4, 4]);

    const outputMap = await session.run([tensorX]);
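    // The model was exported without output_names, so its output is not called
    // 'output'; take the first (and only) output tensor instead.
    const outputData = outputMap.values().next().value;

    // Log the output shape to check the result (expected [1, 4, 12]).
    console.log('ok', outputData.dims);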
}

The bug_example.html page:

<html>
<head>
  <script src="https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js"></script>
  <script src="./bug_example.js"></script>
</head>
<body>
  <div><input type="button" value="Run" onclick="runExample()"/></div>
</body>
</html>

Using opset=10

(opset=9 behaves the same)

torch.onnx.export(model_g, x, "toy_model.onnx", verbose=True, opset_version=10)

output:

pytorch version : 1.5.1
onnx version    : 1.7.0
graph(%input : Float(1, 4, 4),
      %conv.weight : Float(4, 4, 3),
      %conv.bias : Float(4)):
  %3 : Float(1, 4, 10) = onnx::Pad[mode="reflect", pads=[0, 0, 3, 0, 0, 3]](%input) # /home/bram/miniconda3/envs/whispp_env/lib/python3.7/site-packages/torch/nn/functional.py:3397:0
  %4 : Float(1, 4, 12) = onnx::ConvTranspose[dilations=[1], group=1, kernel_shape=[3], pads=[0, 0], strides=[1]](%3, %conv.weight, %conv.bias) # /home/bram/miniconda3/envs/whispp_env/lib/python3.7/site-packages/torch/nn/modules/conv.py:647:0
  return (%4)
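The shapes in this graph can be checked by hand. The ONNX `pads` attribute lists all begin values followed by all end values, one per axis, so [0, 0, 3, 0, 0, 3] pads only the last axis of the (1, 4, 4) input, by 3 on each side, giving length 10. ConvTranspose then follows the output-length formula from the PyTorch ConvTranspose1d documentation; the helper below (a hypothetical name) simply evaluates that formula.

```python
def conv_transpose1d_out_len(l_in, kernel_size, stride=1, padding=0,
                             output_padding=0, dilation=1):
    """Output length of ConvTranspose1d, per the formula in the PyTorch docs."""
    return ((l_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


padded_len = 4 + 3 + 3  # ReflectionPad1d(3): 3 samples added on each side
out_len = conv_transpose1d_out_len(padded_len, kernel_size=3)
print(padded_len, out_len)  # 10 12
```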

Running the JS in the browser gives me:

Uncaught (in promise) TypeError: cannot resolve operator 'ConvTranspose' with opsets: ai.onnx v10

Using opset=11

(opset=12 behaves the same)

torch.onnx.export(model_g, x, "toy_model.onnx", verbose=True, opset_version=11)

output:

pytorch version : 1.5.1
onnx version    : 1.7.0
graph(%input : Float(1, 4, 4),
      %conv.weight : Float(4, 4, 3),
      %conv.bias : Float(4),
      %27 : Long(),
      %28 : Long(2)):
  %3 : int[] = onnx::Constant[value= 3  3 [ CPULongType{2} ]]()
  %4 : Tensor = onnx::Constant[value={0}]()
  %5 : Tensor = onnx::Shape(%3)
  %6 : Tensor = onnx::Gather[axis=0](%5, %4)
  %10 : LongTensor = onnx::Sub(%27, %6)
  %12 : Tensor = onnx::ConstantOfShape[value={0}](%10)
  %13 : Tensor = onnx::Concat[axis=0](%28, %12)
  %14 : Tensor = onnx::Constant[value=-1  2 [ CPULongType{2} ]]()
  %15 : Tensor = onnx::Reshape(%13, %14)
  %16 : Tensor = onnx::Constant[value={0}]()
  %17 : Tensor = onnx::Constant[value={-1}]()
  %18 : Tensor = onnx::Constant[value={-9223372036854775807}]()
  %19 : Tensor = onnx::Constant[value={-1}]()
  %20 : Tensor = onnx::Slice(%15, %17, %18, %16, %19)
  %21 : Tensor = onnx::Transpose[perm=[1, 0]](%20)
  %22 : Tensor = onnx::Constant[value={-1}]()
  %23 : Tensor = onnx::Reshape(%21, %22)
  %24 : Tensor = onnx::Cast[to=7](%23)
  %25 : Float(1, 4, 10) = onnx::Pad[mode="reflect"](%input, %24) # /home/bram/miniconda3/envs/whispp_env/lib/python3.7/site-packages/torch/nn/functional.py:3397:0
  %26 : Float(1, 4, 12) = onnx::ConvTranspose[dilations=[1], group=1, kernel_shape=[3], pads=[0, 0], strides=[1]](%25, %conv.weight, %conv.bias) # /home/bram/miniconda3/envs/whispp_env/lib/python3.7/site-packages/torch/nn/modules/conv.py:647:0
  return (%26)

Running the JS in the browser gives me:

Uncaught (in promise) TypeError: int64 is not supported
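As far as I can tell, the int64 values come from the ONNX spec rather than from a bug: from opset 11 on, Pad takes `pads` as an int64 tensor input instead of an attribute, so the exporter builds that tensor inside the graph and ends with Cast[to=7] (7 is INT64 in ONNX's TensorProto). The chain ConstantOfShape, Concat, Reshape, Slice (step -1), Transpose, Reshape just converts PyTorch's pad list into the ONNX layout. The sketch below mirrors those steps in plain Python; the function name is hypothetical, and I am assuming `%27` is 6 (twice the rank) and `%28` is the [3, 3] pad list. It yields the same [0, 0, 3, 0, 0, 3] that opset 10 stores as an attribute.

```python
def torch_pad_to_onnx_pads(torch_pad, rank):
    """Convert an F.pad-style list (pairs for the LAST dim first) to the ONNX
    Pad layout [d0_begin, ..., dN_begin, d0_end, ..., dN_end].

    Hypothetical helper mirroring the opset-11 subgraph step by step.
    """
    # ConstantOfShape + Concat: append zeros up to 2 * rank entries
    flat = list(torch_pad) + [0] * (2 * rank - len(torch_pad))
    # Reshape(-1, 2): one (begin, end) pair per dim, last dim first
    pairs = [flat[i:i + 2] for i in range(0, len(flat), 2)]
    # Slice with step -1 on axis 0: reverse so the first dim comes first
    pairs = pairs[::-1]
    # Transpose(perm=[1, 0]) + Reshape(-1): all begins, then all ends
    return [p[0] for p in pairs] + [p[1] for p in pairs]


pads = torch_pad_to_onnx_pads([3, 3], rank=3)
print(pads)  # [0, 0, 3, 0, 0, 3]
```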

I guess the earlier opset versions didn’t support every operator yet, but do the more recent ones have a bug related to padding layers somehow? Note that exporting with opset=10 works fine if I leave out ConvTranspose.

What is the advised way to deal with this?

I have the same issue. According to the ONNX.js documentation, ConvTranspose is not supported: [https://github.com/microsoft/onnxjs/blob/master/docs/operators.md]
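One possible workaround, not something proposed in this thread and not tested in ONNX.js, is to avoid ConvTranspose altogether. For stride 1, padding 0, dilation 1 and groups 1, a transposed convolution equals a zero padding of kernel_size - 1 on both sides followed by an ordinary convolution whose weight has its in/out channel axes swapped and its kernel flipped. Conv appears in the ONNX.js operator list as far as I know. The pure-Python sketch below (all function names hypothetical) checks that equivalence numerically on the shapes from this model. In the PyTorch model this would presumably mean replacing `self.conv` by `F.pad(x, (2, 2))` plus an `nn.Conv1d` whose weight is `conv.weight.transpose(0, 1).flip(-1)`; strides above 1 would additionally need zeros inserted between input samples, which this sketch does not cover.

```python
import random


def conv_transpose1d(x, w, b):
    """Naive ConvTranspose1d with stride=1, padding=0, dilation=1, groups=1.

    x: [C_in][L], w: [C_in][C_out][K] (PyTorch weight layout), b: [C_out].
    """
    c_in, length = len(x), len(x[0])
    c_out, k_size = len(w[0]), len(w[0][0])
    y = [[b[o]] * (length + k_size - 1) for o in range(c_out)]
    for i in range(c_in):
        for o in range(c_out):
            for t in range(length):
                for k in range(k_size):
                    # each input sample is scattered over k_size output positions
                    y[o][t + k] += x[i][t] * w[i][o][k]
    return y


def conv1d(x, w, b):
    """Naive Conv1d (cross-correlation) with stride=1 and no padding.

    x: [C_in][L], w: [C_out][C_in][K], b: [C_out].
    """
    c_in, length = len(x), len(x[0])
    c_out, k_size = len(w), len(w[0][0])
    return [[b[o] + sum(x[i][t + k] * w[o][i][k]
                        for i in range(c_in) for k in range(k_size))
             for t in range(length - k_size + 1)]
            for o in range(c_out)]


def conv_transpose_as_conv(x, w, b):
    """Same result as conv_transpose1d, using only zero padding and Conv1d."""
    c_in, c_out, k_size = len(w), len(w[0]), len(w[0][0])
    # zero-pad every channel by K - 1 samples on both sides
    zeros = [0.0] * (k_size - 1)
    x_pad = [zeros + list(row) + zeros for row in x]
    # swap the in/out channel axes and flip each kernel along its length
    w_conv = [[w[i][o][::-1] for i in range(c_in)] for o in range(c_out)]
    return conv1d(x_pad, w_conv, b)


random.seed(0)
x = [[random.uniform(-1, 1) for _ in range(10)] for _ in range(4)]  # (C_in=4, L=10)
w = [[[random.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
     for _ in range(4)]                                              # (4, 4, K=3)
b = [random.uniform(-1, 1) for _ in range(4)]

y_ref = conv_transpose1d(x, w, b)
y_alt = conv_transpose_as_conv(x, w, b)
print(max(abs(p - q) for r1, r2 in zip(y_ref, y_alt) for p, q in zip(r1, r2)))
```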

Can you open an issue in the ONNX.js repo?

I believe these two issues on the onnx.js forums explain the problem best.


I have replied there with a link to this issue. I hope it gets picked up on the ONNX.js side.