Hi!
I’m trying to export my model to ONNX. Here is an example of the model:
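```python
import torch
import torch.nn as nn

class mfm(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1):
        super(mfm, self).__init__()
        self.out_channels = out_channels
        # The filter produces 2*out_channels features, which forward() splits in half.
        if type == 1:
            self.filter = nn.Conv2d(in_channels, 2*out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        else:
            self.filter = nn.Linear(in_channels, 2*out_channels)

    def forward(self, x):
        x = self.filter(x)
        # Split along the channel dimension into two halves of out_channels each ...
        out = torch.split(x, self.out_channels, 1)
        # ... and keep the element-wise maximum of the two halves.
        return torch.max(out[0], out[1])
```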
When I export it, I get the following error:
```
RuntimeError: ONNX export failed: Couldn't export operator narrow
Graph we tried to export:
graph(%1 : Float(1, 1, 256, 128)
%2 : Float(96, 1, 5, 5)
%3 : Float(96)) {
%5 : UNKNOWN_TYPE = Conv[kernel_shape=[5, 5], strides=[1, 1], pads=[2, 2, 2, 2], dilations=[1, 1], group=1](%1, %2), uses = [[%6.i0]];
%6 : Float(1, 96, 256, 128) = Add[broadcast=1, axis=1](%5, %3), uses = [%7.i0, %8.i0];
%7 : Float(1!, 48, 256, 128) = narrow[dimension=1, start=0, length=48](%6), uses = [%9.i0];
%8 : Float(1!, 48, 256, 128) = narrow[dimension=1, start=48, length=48](%6), uses = [%9.i1];
%9 : Float(1, 48, 256, 128) = max(%7, %8), uses = [%0.i0];
return (%9);
}
```
According to this page, split is supported by ONNX, but narrow is not.
Could someone suggest an alternative to split that ONNX supports, or another way to avoid this?
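For reference, here is a sketch of the same computation written without torch.split, in case that sidesteps the narrow op. It assumes the exporter can handle `view` and `max(dim=...)` for this PyTorch version, which I haven't verified. `MfmReshape` and `split_reference` are hypothetical names used only for illustration. The idea is to view the `2*out_channels` channels as `(2, out_channels)` and take the max over the new size-2 axis, which selects from the same two halves that split would return.

```python
import torch
import torch.nn as nn


class MfmReshape(nn.Module):
    """Hypothetical split-free variant of mfm (unverified for ONNX export)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1):
        super().__init__()
        self.out_channels = out_channels
        if type == 1:
            self.filter = nn.Conv2d(in_channels, 2 * out_channels,
                                    kernel_size=kernel_size, stride=stride, padding=padding)
        else:
            self.filter = nn.Linear(in_channels, 2 * out_channels)

    def forward(self, x):
        x = self.filter(x)
        # (N, 2*C, ...) -> (N, 2, C, ...): index 0 holds channels [0, C) and
        # index 1 holds [C, 2C), the same halves torch.split(x, C, 1) returns.
        x = x.view(x.size(0), 2, self.out_channels, *x.shape[2:])
        # Max over the size-2 axis equals the element-wise max of the two halves.
        return x.max(dim=1)[0]


def split_reference(module, x):
    # The original split-based forward, reusing the same filter weights.
    y = module.filter(x)
    a, b = torch.split(y, module.out_channels, 1)
    return torch.max(a, b)


torch.manual_seed(0)
# Shapes taken from the graph in the error: 1 input channel, 96 = 2*48 filters, 5x5 kernel, padding 2.
m = MfmReshape(1, 48, kernel_size=5, padding=2)
x = torch.randn(1, 1, 256, 128)
out = m(x)
print(out.shape)  # torch.Size([1, 48, 256, 128])
print(torch.equal(out, split_reference(m, x)))  # True
```

Because `max` only selects one of the two existing values, the reshaped version should match the split version exactly, not just approximately. The `*x.shape[2:]` unpacking lets the same code cover the `nn.Linear` branch, where the output is just `(N, 2*C)`.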
Thanks!