Hi I’m trying to export PyTorch custom layer to to onnx and got this error:
Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int
I checked same issues and tried to resolve:
But cannot do it.
There is my code:
import torch
import dcn_op_v2
from torch.autograd import Function
from torch.nn import Module
import torch.nn as nn
import math
class DeformableConv2DFunction(Function):
@staticmethod
def forward(ctx, input_tensor, weight, bias, offset, mask, stride, pad, dilation, deformable_groups):
ctx.stride_h = stride[0]
ctx.stride_w = stride[1]
ctx.pad_h = pad[0]
ctx.pad_w = pad[1]
ctx.dilation_h = dilation[0]
ctx.dilation_w = dilation[1]
ctx.deformable_groups = deformable_groups
output = dcn_op_v2.forward(
input_tensor,
weight,
bias,
offset,
mask,
ctx.stride_h, ctx.stride_w,
ctx.pad_h, ctx.pad_w,
ctx.dilation_h, ctx.dilation_w,
ctx.deformable_groups
)
ctx.save_for_backward(input_tensor, weight, offset, mask, bias)
return output
@staticmethod
def backward(ctx, *grad_outputs):
input_tensor, weight, offset, mask, bias = ctx.saved_tensors
grad_input, grad_weight, grad_bias, grad_offset, grad_mask = dcn_op_v2.backward(
input_tensor,
weight,
bias,
offset,
mask,
grad_outputs[0],
ctx.stride_h, ctx.stride_w,
ctx.pad_h, ctx.pad_w,
ctx.dilation_h, ctx.dilation_w,
ctx.deformable_groups
)
return grad_input, grad_weight, grad_bias, grad_offset, grad_mask, \
None, None, None, None
class DeformableConv2DLayer(Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
deformable_groups):
super().__init__()
def typexam(x):
if type(x)==int:
return (x, x)
elif type(x)==tuple and len(x)==2:
return x
else:
raise TypeError
kernel_size = typexam(kernel_size)
stride = typexam(stride)
padding = typexam(padding)
dilation = typexam(dilation)
self.stride = stride
self.pad = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
self.bias = nn.Parameter(torch.Tensor(out_channels))
self.conv_offset_mask = nn.Conv2d(in_channels,
self.deformable_groups * 3 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=self.stride,
padding=self.pad,
bias=True)
self.reset_parameters(in_channels, kernel_size)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def reset_parameters(self, in_channels, kernel_size):
n = in_channels
for k in kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
self.bias.data.zero_()
def forward(self, inputs):
out = self.conv_offset_mask(inputs)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return DeformableConv2DFunction.apply(inputs, self.weight, self.bias, offset, mask, self.stride, self.pad, self.dilation, self.deformable_groups)
from torch.onnx import register_custom_op_symbolic
def dcn_v2(g, self):
return g.op("ai.onnx.contrib::Dcn_v2", self)
register_custom_op_symbolic('::dcn_v2', dcn_v2, 1)
import io
import onnx
N, inC, inH, inW = 2, 2, 4, 4
outC = 2
kH, kW = 3, 3
stride = (1,1)
padding = (1,1)
dilation= (1,1)
deformable_groups = 1
t = torch.randn(N, inC, inH, inW)
offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW)
mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW)
# mask.data.zero_()
mask = torch.sigmoid(mask)
weight = torch.randn(outC, inC, kH, kW)
bias = torch.rand(outC)
# Export model to ONNX
t_model = DeformableConv2DLayer(inC, outC, (kH,kW), stride, padding, dilation, deformable_groups)
input_nn = (t, weight, bias, offset, mask, stride, padding, dilation, torch.tensor(deformable_groups, dtype=torch.int32))
with torch.no_grad():
torch.onnx.export(t_model, input_nn, 'models/dcn.onnx', export_params=True)
Full ouput of error is:
Traceback (most recent call last):
File "custom_onnx_op.py", line 158, in <module>
torch.onnx.export(t_model, input_nn, 'models/dcn.onnx', export_params=True)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/__init__.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 701, in _export
dynamic_axes=dynamic_axes)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 459, in _model_to_graph
use_new_jit_passes)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/appuser/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 93, in forward
in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int
So I can not understand how to deal with it.