Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int

Hi I’m trying to export PyTorch custom layer to to onnx and got this error:

Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int

I checked same issues and tried to resolve:

But cannot do it.
There is my code:

import torch
import dcn_op_v2
from torch.autograd import Function
from torch.nn import Module
import torch.nn as nn
import math


class DeformableConv2DFunction(Function):
    @staticmethod
    def forward(ctx, input_tensor, weight, bias, offset, mask, stride, pad, dilation, deformable_groups):
        ctx.stride_h = stride[0]
        ctx.stride_w = stride[1]
        ctx.pad_h = pad[0]
        ctx.pad_w = pad[1]
        ctx.dilation_h = dilation[0]
        ctx.dilation_w = dilation[1]
        ctx.deformable_groups = deformable_groups

        output = dcn_op_v2.forward(
            input_tensor,
            weight,
            bias,
            offset,
            mask,
            ctx.stride_h, ctx.stride_w,
            ctx.pad_h, ctx.pad_w,
            ctx.dilation_h, ctx.dilation_w,
            ctx.deformable_groups
        )
        ctx.save_for_backward(input_tensor, weight, offset, mask, bias)
        return output
    
    @staticmethod
    def backward(ctx, *grad_outputs):
        input_tensor, weight, offset, mask, bias = ctx.saved_tensors
        grad_input, grad_weight, grad_bias, grad_offset, grad_mask = dcn_op_v2.backward(
            input_tensor,
            weight,
            bias,
            offset,
            mask,
            grad_outputs[0],
            ctx.stride_h, ctx.stride_w,
            ctx.pad_h, ctx.pad_w,
            ctx.dilation_h, ctx.dilation_w,
            ctx.deformable_groups
        )
        
        return grad_input, grad_weight, grad_bias, grad_offset, grad_mask, \
            None, None, None, None

class DeformableConv2DLayer(Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 deformable_groups):
        super().__init__()
        def typexam(x):
            if type(x)==int:
                return (x, x)
            elif type(x)==tuple and len(x)==2:
                return x
            else:
                raise TypeError
        kernel_size = typexam(kernel_size)
        stride = typexam(stride)
        padding = typexam(padding)
        dilation = typexam(dilation)
        self.stride = stride
        self.pad = padding
        self.dilation = dilation
        self.deformable_groups = deformable_groups
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.conv_offset_mask = nn.Conv2d(in_channels,
                                          self.deformable_groups * 3 * kernel_size[0] * kernel_size[1],
                                          kernel_size=kernel_size,
                                          stride=self.stride,
                                          padding=self.pad,
                                          bias=True)
        self.reset_parameters(in_channels, kernel_size)
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def reset_parameters(self, in_channels, kernel_size):
        n = in_channels
        for k in kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.zero_()

    def forward(self, inputs):
        out = self.conv_offset_mask(inputs)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return DeformableConv2DFunction.apply(inputs, self.weight, self.bias, offset, mask, self.stride, self.pad, self.dilation, self.deformable_groups)

from torch.onnx import register_custom_op_symbolic

def dcn_v2(g, self):
    return g.op("ai.onnx.contrib::Dcn_v2", self)

register_custom_op_symbolic('::dcn_v2', dcn_v2, 1)


import io
import onnx


N, inC, inH, inW = 2, 2, 4, 4
outC = 2
kH, kW = 3, 3
stride = (1,1)
padding = (1,1)
dilation= (1,1)
deformable_groups = 1

t = torch.randn(N, inC, inH, inW)

offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW)

mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW)
# mask.data.zero_()
mask = torch.sigmoid(mask)

weight = torch.randn(outC, inC, kH, kW)
bias = torch.rand(outC)



# Export model to ONNX
t_model = DeformableConv2DLayer(inC, outC, (kH,kW), stride, padding, dilation, deformable_groups)
input_nn = (t, weight, bias, offset, mask, stride, padding, dilation, torch.tensor(deformable_groups, dtype=torch.int32))
with torch.no_grad():
    torch.onnx.export(t_model, input_nn, 'models/dcn.onnx',  export_params=True)

Full ouput of error is:

Traceback (most recent call last):
  File "custom_onnx_op.py", line 158, in <module>
    torch.onnx.export(t_model, input_nn, 'models/dcn.onnx',  export_params=True)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/__init__.py", line 276, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
    use_external_data_format=use_external_data_format)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 701, in _export
    dynamic_axes=dynamic_axes)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 459, in _model_to_graph
    use_new_jit_passes)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 93, in forward
    in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int

So I can not understand how to deal with it.