Onnx error with 'Unsupported index type <class 'slice'>'

Hi,

I got an error message when exporting the following model with onnx. It looks like onnx does not support slice operator. In the model definition, I just want the last output of lstm layer. So I use slice operator, i.e., ‘output[:, -1, :]’. In addition, the ‘view’ method will get the same error with onnx exporting.

  1. The model definition:

class Model1(torch.nn.Module):
    def __init__(self, use_cuda, input_size, hidden_size, output_size):
        super(Model1, self).__init__()
        self.model_name = 'model1'
        self.use_cuda = use_cuda
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hidden_layer = torch.nn.LSTM(input_size, hidden_size)
        self.output_layer = torch.nn.Linear(hidden_size, output_size)


    def dummy_features(self, batch_size):
        seq_tensor = Variable(torch.randn(batch_size, MAX_SEQ_LEN, self.input_size), requires_grad=True).float()
        seq_lengths = Variable(torch.randn(batch_size), requires_grad=True).long()
        return seq_tensor, seq_lengths


    def forward(self, seq_tensor, seq_lengths):
        batch_size = seq_tensor.shape[0]

        h0 = Variable(torch.rand(1, batch_size, self.hidden_size), requires_grad=False)
        c0 = Variable(torch.rand(1, batch_size, self.hidden_size), requires_grad=False)

        if self.use_cuda:
            h0 = h0.cuda()
            c0 = c0.cuda()

        inputs = pack_padded_sequence(seq_tensor, seq_lengths.cpu().data.numpy(), batch_first=True)
        output, (hn, cn) = self.hidden_layer(inputs, (h0, c0))
        output, _ = pad_packed_sequence(output, batch_first=True)
        output = self.output_layer(output[:, -1, :])
        return output

  1. The model graph:
graph(%1 : Float(1, 50, 256)
      %2 : Long(1)
      %3 : Float(64, 256)
      %4 : Float(64, 16)
      %5 : Float(64)
      %6 : Float(64)
      %7 : Float(2, 16)
      %8 : Float(2)) {
  %9 : Float(50!, 1!, 256) = Transpose[perm=[1, 0, 2]](%1), uses = [%10.i0];
  %10 : Float(1!, 1!, 256) = Slice[axes=[0, 1], ends=[1, 1], starts=[0, 0]](%9), uses = [%11.i0];
  %11 : Float(1, 256) = Reshape[shape=[-1, 256]](%10), uses = [%12.i0];
  %12 : Float(1, 256) = Concat[axis=0](%11), uses = [];
  return ();
}
  1. The export code:
onnx_features = cpu_model.dummy_features(onnx_batch_size)
torch.onnx.export(cpu_model, onnx_features, os.path.join(model_save_path, 'model.onnx'), export_params=True)
  1. The error message:
  File "/home/abc/platform/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/__init__.py", line 75, in export
    _export(model, args, f, export_params, verbose, training)
  File "/home/abc/platform/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/__init__.py", line 122, in _export
    _optimize_trace(trace)
  File "/home/abc/platform/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/__init__.py", line 81, in _optimize_trace
    torch._C._jit_pass_onnx(trace)
  File "/home/abc/platform/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/__init__.py", line 148, in _run_symbolic_method
    return symbolic_fn(*args)
  File "/home/abc/platform/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/autograd/_functions/tensor.py", line 77, in symbolic
    raise ValueError('Unsupported index type {}'.format(type(index)))
ValueError: Unsupported index type <class 'slice'>
  1. An additional operator with the same error message (I want all the outputs of lstm and reshaped):
output.contiguous().view(batch_size, -1)
1 Like