[ONNX] Export model with pack_padded_sequence to ONNX fails

Hi all! Thank you for your time helping me out with this problem. I am trying to export my module to ONNX, but whenever the module uses nn.utils.rnn.pack_padded_sequence anywhere, the export fails with this error:

RuntimeError: ONNX export failed: Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.

My model only needs to use one of the two… Also, aren’t these inverses of each other? So using them in pairs would just give you back your original tensor, right…?
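To make the "inverses" question concrete, here is a minimal sketch of the round trip I have in mind. The shapes, values, and lengths are made up for illustration and are not from my real model:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Hypothetical batch: 2 sequences padded to 4 time steps, 3 features each.
x = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
lengths = torch.tensor([2, 4])  # true lengths (int64, on CPU)

# Pack, then immediately unpack again.
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)

# The valid time steps of each sequence come back unchanged...
valid_steps_match = torch.equal(unpacked[0, :2], x[0, :2]) and torch.equal(unpacked[1], x[1])
# ...but positions past a sequence's length are filled with the padding value (0 by default).
padding_is_zero = bool((unpacked[0, 2:] == 0).all())
print(valid_steps_match, padding_is_zero)
```

So, as far as I can tell, the pair is only an inverse on the valid time steps: whatever was stored in the padded positions is replaced by zeros.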

Here is the code needed to reproduce the problem. I create a sample net that uses the same features as my model in question, with a pack_padded_sequence included. I create test inputs, and run them through the model with no problem to make a test output. I then attempt to export the model to ONNX, and it fails, due to the pack_padded_sequence.

import torch
from torch import nn

class SampleNet(nn.Module):
    def __init__(self):

        super().__init__()

        self.BatchNorm1 = nn.BatchNorm1d(5)

    def forward(self, X, X_len):
        # Mask
        X_pack = torch.nn.utils.rnn.pack_padded_sequence(X, X_len, batch_first=True, enforce_sorted=False)

        # norm
        X_ln = torch.nn.utils.rnn.PackedSequence(
            data=self.BatchNorm1(X_pack.data),
            batch_sizes=X_pack.batch_sizes,
            sorted_indices=X_pack.sorted_indices,
            unsorted_indices=X_pack.unsorted_indices,
        )
        
        return X_ln

example_input = torch.randn(1, 30, 5, requires_grad=True)
example_input_length = torch.abs(torch.randn(1, requires_grad=True)) + 1

print("Example input: %s" % example_input)
print("Example input size: %s" % list(example_input.size()))

net = SampleNet() # Model instantiation
net.eval()

example_output = net(example_input, example_input_length) # Check that forward pass works as expected:
print(f"Output from Net (works): {example_output}")


torch.onnx.export(model=net,
                  args=(example_input, example_input_length),
                  example_outputs=example_output,
                  f="test_pack_unpack.onnx",
                  verbose=True,
                  input_names=["x", "x_length"],
                  output_names=["preds"]
                 )# Export to ONNX (fails)

And here are the resulting error messages:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-8120f5b33ec5> in <module>
     39 
     40 # Export to ONNX (fails):
---> 41 torch.onnx.export(model=net,
     42                   args=(example_input, example_input_length),
     43                   example_outputs=example_output,

~/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    161 
    162     from torch.onnx import utils
--> 163     return utils.export(model, args, f, export_params, verbose, training,
    164                         input_names, output_names, aten, export_raw_ir,
    165                         operator_export_type, opset_version, _retain_param_name,

~/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     61         else:
     62             operator_export_type = OperatorExportTypes.ONNX
---> 63     _export(model, args, f, export_params, verbose, training, input_names, output_names,
     64             operator_export_type=operator_export_type, opset_version=opset_version,
     65             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,

~/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format)
    498 
    499         if export_params:
--> 500             proto, export_map = graph._export_onnx(
    501                 params_dict, opset_version, dynamic_axes, defer_weight_export,
    502                 operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,

RuntimeError: ONNX export failed: Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.

Usage of this operation occurred at:
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py(244): pack_padded_sequence
<ipython-input-3-8120f5b33ec5>(13): forward
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/modules/module.py(548): __call__
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/jit/__init__.py(348): wrapper
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/jit/__init__.py(357): forward
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/modules/module.py(550): __call__
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/jit/__init__.py(278): _get_trace_graph
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/utils.py(291): _trace_and_get_graph_from_model
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/utils.py(334): _model_to_graph
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/utils.py(483): _export
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/utils.py(63): export
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/onnx/__init__.py(163): export
<ipython-input-3-8120f5b33ec5>(41): <module>
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3331): run_code
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3254): run_ast_nodes
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3062): run_cell_async
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2886): _run_cell
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2857): run_cell
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/ipykernel/zmqshell.py(536): run_cell
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/ipykernel/ipkernel.py(300): do_execute
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/gen.py(209): wrapper
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/ipykernel/kernelbase.py(543): execute_request
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/gen.py(209): wrapper
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/ipykernel/kernelbase.py(268): dispatch_shell
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/gen.py(209): wrapper
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/ipykernel/kernelbase.py(365): process_one
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/gen.py(748): run
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/gen.py(787): inner
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/ioloop.py(743): _run_callback
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/ioloop.py(690): <lambda>
/usr/lib/python3.8/asyncio/events.py(81): _run
/usr/lib/python3.8/asyncio/base_events.py(1844): _run_once
/usr/lib/python3.8/asyncio/base_events.py(563): run_forever
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/tornado/platform/asyncio.py(149): start
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/ipykernel/kernelapp.py(597): start
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/traitlets/config/application.py(664): launch_instance
/home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/ipykernel_launcher.py(16): <module>
/usr/lib/python3.8/runpy.py(85): _run_code
/usr/lib/python3.8/runpy.py(192): _run_module_as_main


Graph we tried to export:
graph(%x : Float(1, 30, 5),
      %x_length : Float(1)):
  %7 : Long(1) = onnx::Cast[to=7](%x_length) # /home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py:234:0
  %8 : Long(1), %9 : Long(1) = onnx::TopK[axis=-1, k=1](%7) # /home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py:238:0
  %10 : Long(1) = onnx::Cast[to=7](%9) # /home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py:239:0
  %11 : Float(1, 30, 5) = onnx::Gather[axis=0](%x, %10) # /home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py:241:0
  %12 : Tensor = onnx::Transpose[perm=[1, 0, 2]](%11)
  %13 : Tensor = onnx::Cast[to=6](%8)
  %preds : Float(1, 5), %15 : Long(1) = prim::PackPadded(%12, %13) # /home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py:244:0
  %16 : Tensor = onnx::Shape(%10)
  %17 : Long(1) = onnx::ConstantOfShape[value={0}](%16) # /home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py:190:0
  %18 : Long(1) = onnx::Constant[value={0}]()
  %19 : Long(1) = onnx::Scatter[axis=0](%17, %10, %18) # /home/z/.local/share/virtualenvs/z-jkfIRWOP/lib/python3.8/site-packages/torch/nn/utils/rnn.py:191:0
  return (%preds, %15, %10, %19)

Also, here is my environment if this is useful:

PyTorch version: 1.5.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: TITAN RTX
GPU 1: TITAN RTX
GPU 2: TITAN RTX
GPU 3: TITAN RTX

Nvidia driver version: 450.36.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] guided-filter-pytorch==3.7.5
[pip3] numpy==1.18.5
[pip3] pytorch-lightning==0.9.0
[pip3] torch==1.5.1
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.6.1

To work around this, I built my own pack_padded_sequence function like so:

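def my_pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
    # Note: enforce_sorted is accepted for signature compatibility but ignored;
    # the batch is always sorted by length below.
    lengths = torch.as_tensor(lengths, dtype=torch.int64)

    # Sort sequences by decreasing length and reorder the batch to match.
    lengths, sorted_indices = torch.sort(lengths, descending=True)
    sorted_indices = sorted_indices.to(input.device)
    batch_dim = 0 if batch_first else 1
    input = input.index_select(batch_dim, sorted_indices)

    # Private op that does the actual packing; this is the line ONNX complains about.
    data, batch_sizes = torch._VF._pack_padded_sequence(input, lengths, batch_first)

    # Invert the sort permutation so the original batch order can be restored later.
    unsorted_indices = torch.empty_like(sorted_indices, memory_format=torch.legacy_contiguous_format)
    unsorted_indices.scatter_(0, sorted_indices, torch.arange(0, sorted_indices.numel(), device=sorted_indices.device))

    return data, batch_sizes, sorted_indices, unsorted_indices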

However, I can’t work out what the call torch._VF._pack_padded_sequence(input, lengths, batch_first) actually does. I can’t find it anywhere in the Python source, so I’m guessing it is implemented in C/C++? In any case, even with this workaround, the ONNX export still fails on that line.

Is there any way to fix this issue? Perhaps some way to learn the behavior of torch._VF._pack_padded_sequence so I can write it myself? There seem to be open issues on this that are over a year old, but with no response from the PyTorch team.
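For reference, here is my current understanding of what the packing step computes, based only on the documented PackedSequence layout (valid time steps concatenated in time-major order, plus the number of active sequences at each step). This is a sketch of my own, not the actual internal implementation; the name pack_padded_reference is hypothetical, and it assumes the lengths are already sorted in descending order:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

def pack_padded_reference(padded, lengths, batch_first=False):
    """Hypothetical reference for the packing step (not the real kernel).

    Assumes `lengths` is sorted in descending order.
    """
    if batch_first:
        padded = padded.transpose(0, 1)  # -> (T, B, *)
    lengths = [int(n) for n in lengths]
    steps, batch_sizes = [], []
    for t in range(lengths[0]):  # lengths[0] is the longest sequence
        # Number of sequences that still have a valid element at step t.
        n_active = sum(1 for n in lengths if n > t)
        steps.append(padded[t, :n_active])
        batch_sizes.append(n_active)
    data = torch.cat(steps, dim=0)
    return data, torch.tensor(batch_sizes, dtype=torch.int64)

# Compare against the built-in function on made-up data.
x = torch.randn(3, 5, 2)
lengths = torch.tensor([5, 3, 1])
data, batch_sizes = pack_padded_reference(x, lengths, batch_first=True)
expected = pack_padded_sequence(x, lengths, batch_first=True)
print(torch.equal(data, expected.data), torch.equal(batch_sizes, expected.batch_sizes))
```

I have not checked whether anything like this exports to ONNX. As far as I understand, a Python loop over concrete lengths would be unrolled during tracing with the example lengths baked in, so it probably would not generalize to other inputs as written.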

Thank you so much in advance for your time!