Dynamic output shape of nn.Upsample causes InstanceNorm2d error in ONNX export

Environment:
Ubuntu 20.04, Python 3.8, PyTorch 1.10 (cu113, installed via pip), NumPy 1.21.4

Here is a minimal example:

import torch
import torch.nn as nn


class TestModel(nn.Module):

    def __init__(self, num_features, init_size=None):
        super(TestModel, self).__init__()

        if init_size is None:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.upsample = nn.Upsample(size=init_size * 2, mode='bilinear', align_corners=True)

        self.innorm = nn.InstanceNorm2d(num_features, affine=False)

    def forward(self, x):
        out = self.upsample(x)
        out = self.innorm(out)
        return out


if __name__ == '__main__':

    init_size = 64
    num_features = 32
    x = torch.randn(4, num_features, init_size, init_size)

    model = TestModel(num_features, init_size=None)
    # model = TestModel(num_features, init_size=init_size)
    output = model(x)
    print(output.size())
    # torch.Size([4, 32, 128, 128])

    torch.onnx.export(
        model, x, 'test.onnx',
        export_params=True,
        verbose=True,
        opset_version=14,  # 9 ~ 14
        input_names=['x'],
        output_names=['output']
    )

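The following information was obtained by adding print(input) to instance_norm in torch/onnx/symbolic_opset9.py. The type Float(*, *, *, *) shows that every dimension of the nn.Upsample output, including the channel dimension, is unknown (dynamic) in the exported graph: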

8 defined in (%8 : Float(*, *, *, *, strides=[524288, 16384, 128, 1], requires_grad=0, device=cpu) = onnx::Resize[coordinate_transformation_mode="align_corners", cubic_coeff_a=-0.75, mode="linear", nearest_mode="floor"](%input.1, %7, %6) # /home/quqixun/miniconda3/envs/pt10/lib/python3.8/site-packages/torch/nn/functional.py:3731:0)

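Because channel_size is None, and nn.InstanceNorm2d was created with affine=False (so it has no weight or bias tensor to pass along), instance_norm in onnx/symbolic_opset9.py cannot build default scale and bias tensors and raises the following error: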

RuntimeError: Unsupported: ONNX export of instance_norm for unknown channel size.
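One possible workaround, sketched here as an assumption to verify on your PyTorch version: with affine=False the module registers no weight or bias, so the exporter has to derive them from the channel count. With affine=True, PyTorch initializes weight to ones and bias to zeros, so the forward pass should match affine=False while giving the exporter explicit tensors. The snippet below only checks the numerical equivalence on the shapes from the example; it does not run the export itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_features = 32

# affine=False: no learnable parameters are registered (weight/bias are None).
norm_plain = nn.InstanceNorm2d(num_features, affine=False)
# affine=True: weight is initialized to ones and bias to zeros.
norm_affine = nn.InstanceNorm2d(num_features, affine=True)

# Same shape as the nn.Upsample output in the example.
x = torch.randn(4, num_features, 128, 128)

with torch.no_grad():
    out_plain = norm_plain(x)
    out_affine = norm_affine(x)

print(norm_plain.weight is None)  # True
print(torch.equal(norm_affine.weight, torch.ones(num_features)))
print(torch.allclose(out_plain, out_affine, atol=1e-6))
```

If the model was trained with affine=False, keeping the new weight and bias frozen at their initial values preserves the original behavior.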

Question:
Is there any way to make the above code work?
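One possible workaround, sketched under the assumption that plain tensor ops are acceptable in the exported graph, is to replace nn.InstanceNorm2d(affine=False) with a hand-written normalization. It computes the per-sample, per-channel mean and biased variance over H and W, which never requires the channel count, so the instance_norm symbolic is not involved at all. ManualInstanceNorm2d and ExportableModel are hypothetical names for this sketch, not PyTorch APIs.

```python
import torch
import torch.nn as nn


class ManualInstanceNorm2d(nn.Module):
    """Hypothetical stand-in for nn.InstanceNorm2d(affine=False)."""

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps  # same default eps as nn.InstanceNorm2d

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics over the spatial dims (H, W) of each sample and channel.
        mean = x.mean(dim=(2, 3), keepdim=True)
        # Biased variance, as used by InstanceNorm2d for normalization.
        var = ((x - mean) ** 2).mean(dim=(2, 3), keepdim=True)
        return (x - mean) / torch.sqrt(var + self.eps)


class ExportableModel(nn.Module):
    """Hypothetical variant of TestModel using the manual normalization."""

    def __init__(self):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.innorm = ManualInstanceNorm2d()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.innorm(self.upsample(x))


torch.manual_seed(0)
x = torch.randn(4, 32, 64, 64)
model = ExportableModel()

# Reference: the original Upsample + InstanceNorm2d(affine=False) pipeline.
reference = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
    nn.InstanceNorm2d(32, affine=False),
)

with torch.no_grad():
    out = model(x)
    max_diff = (out - reference(x)).abs().max().item()

print(out.shape)  # torch.Size([4, 32, 128, 128])
print(max_diff < 1e-4)
```

ExportableModel() could then be passed to the same torch.onnx.export call as in the question; the exported graph should contain ReduceMean, Sub, Pow, Sqrt and Div nodes instead of InstanceNormalization.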

Thanks a lot.
