I’m trying to convert a (fairly simple) 1D depthwise-separable resnet to ONNX. However, when calling torch.onnx.export
, I’m getting an UnsupportedOperatorError:
torch.onnx.symbolic_registry.UnsupportedOperatorError: Exporting the operator ::_convolution_mode to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
I don’t see any obvious operators that shouldn’t be supported. Code to reproduce is provided below:
import torch
import torch.nn as nn
class DepthwiseSeparable1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, padding='valid'):
super(DepthwiseSeparable1D, self).__init__()
self.dw = nn.Conv1d(
in_channels=in_channels,
out_channels=in_channels,
groups=in_channels,
kernel_size=kernel_size,
bias=bias,
stride=stride,
padding=padding
)
self.pw = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias,
padding=padding,
)
def forward(self, x):
x = self.dw(x)
x = self.pw(x)
return x
class ResidualBlock1D(nn.Module):
def __init__(self, sub_blocks, filters, kernel_size, stride=1, padding='valid'):
super(ResidualBlock1D, self).__init__()
self.sub_blocks = sub_blocks
self.padding = padding
self.conv1a = nn.Conv1d(
in_channels=225, out_channels=filters,
kernel_size=1, stride=stride, bias=False
)
self.bn1a = nn.BatchNorm1d(filters)
self.block_conv = nn.Sequential()
self.bn = nn.Sequential()
for i in range(self.sub_blocks):
self.block_conv.append(
DepthwiseSeparable1D(
in_channels=filters,
out_channels=filters,
kernel_size=kernel_size,
stride=stride if i == 0 else 1,
padding=padding,
bias=False,
)
)
self.bn.append(nn.BatchNorm1d(filters))
def forward(self, inputs):
res = self.conv1a(inputs)
res = self.bn1a(res)
output = inputs
for i in range(self.sub_blocks):
output = self.block_conv[i](output)
output = self.bn[i](output)
if i == self.sub_blocks - 1:
if self.padding == 'valid':
res = res[:, :, -output.shape[-1]:]
output = res + output
output = nn.ReLU(inplace=True)(output)
return output
class SeparableResNet1D(nn.Sequential):
def __init__(
self,
blocks,
sub_blocks,
filters,
kernel_size,
strides,
last_layer_filters = 1024,
padding='valid',
):
super(SeparableResNet1D, self).__init__()
self.padding = padding
if not strides:
strides = [1]
strides += [1 for _ in range(blocks + 1 - len(strides))]
if isinstance(kernel_size, int):
kernel_size = [kernel_size for _ in range(blocks + 1)]
# Depthwise conv
self.append(
DepthwiseSeparable1D(
in_channels=64,
out_channels=filters,
bias=False,
kernel_size=kernel_size[0],
stride=strides[0],
padding=self.padding,
)
)
for k in range(blocks):
self.append(
ResidualBlock1D(
sub_blocks=sub_blocks,
filters=filters,
kernel_size=kernel_size[k+1],
padding=padding,
stride=strides[k+1]
)
)
self.append(nn.Conv1d(
in_channels=filters,
out_channels=last_layer_filters,
kernel_size=1,
bias=False,
padding='valid'
))
self.append(
nn.BatchNorm1d(num_features=last_layer_filters)
)
self.append(
nn.ReLU()
)
class MyModel(nn.Module):
def __init__(self, encoder, last_layer_filters, num_classes=2):
super(MyModel, self).__init__()
self.encoder = encoder
self.linear_model = nn.Sequential(
nn.Conv1d(in_channels=last_layer_filters, out_channels=2, kernel_size=1), nn.Softmax(dim=1)
)
def forward(self, inputs):
out = self.linear_model(self.encoder(inputs))
return out
last_layer_filters = 1024
encoder = SeparableResNet1D(
blocks=6, sub_blocks=3, filters=225, kernel_size=3, strides=[3], padding='valid',
last_layer_filters=last_layer_filters
)
model = MyModel(
encoder, last_layer_filters
)
model.eval()
torch.save(model.state_dict(), "toy_resnet_L.pth")
dummy_input = torch.randn([2, 64, 111])
export_model = MyModel(
encoder, last_layer_filters
)
export_model.load_state_dict(torch.load("toy_resnet_L.pth"))
export_model.eval()
out = export_model(dummy_input)
torch.onnx.export(
export_model, dummy_input, "resnet_L.random.onnx",
verbose=True,
)