How to trace LSTM with pack_padded_sequence inputs?

I am trying to use the JIT to get a traced version of an LSTM.

Here is some example code:

import torch
import torch.nn as nn

a = torch.randn(8, 5, 30)  # batch of 8 examples, 5 time steps, 30 features
b = torch.randn(16, 10, 30)  # batch of 16 examples, 10 time steps, 30 features

lstm = torch.jit.trace(nn.LSTM(30, 25, batch_first=True), a)


As expected, the above works fine. However, I want to use pack_padded_sequence:

lengths = torch.randint(low=1, high=b.shape[1], size=(len(b),))
lengths, _ = torch.sort(lengths, descending=True)
b = nn.utils.rnn.pack_padded_sequence(b, lengths, batch_first=True)

This doesn’t work:

>>> lstm(b)

RuntimeError: forward() expected value of type Tensor for argument 'input' in position 0, but instead got value of type PackedSequence.

Okay, the traced LSTM expects a Tensor, not a packed sequence. But tracing the LSTM with a packed sequence input also doesn’t work:

lstm = torch.jit.trace(nn.LSTM(30, 25, batch_first=True), b)

RuntimeError: input must have 3 dimensions, got 2
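A plausible explanation for this second error, assuming torch.jit.trace behaves as documented (a tuple passed as example inputs is unpacked into positional arguments): PackedSequence is a namedtuple subclass, so the tracer sees its fields as separate arguments, and the LSTM receives the 2-D `data` field as `input`. The exact set of fields depends on the PyTorch version; recent versions carry `sorted_indices` and `unsorted_indices` besides `data` and `batch_sizes`. The snippet below rebuilds the packed batch from the post (the seed is an arbitrary addition) and only inspects that structure.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, added for reproducibility
b = torch.randn(16, 10, 30)  # batch of 16 examples, 10 time steps, 30 features
# randint's `high` is exclusive, so every length here is between 1 and 9
lengths = torch.randint(low=1, high=b.shape[1], size=(len(b),))
lengths, _ = torch.sort(lengths, descending=True)
packed = nn.utils.rnn.pack_padded_sequence(b, lengths, batch_first=True)

# PackedSequence is a namedtuple subclass, not a Tensor, so torch.jit.trace
# treats it as a tuple of example inputs rather than as one argument.
print(isinstance(packed, tuple))         # True
print(isinstance(packed, torch.Tensor))  # False

# The valid time steps of all sequences are stacked into one 2-D tensor of
# shape (total number of valid steps, features).
print(packed.data.dim())                 # 2
print(packed.data.shape[0] == lengths.sum().item())  # True

# batch_sizes[t] counts the sequences still active at time step t.
print(packed.batch_sizes)
```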

I realise I could probably put it into a ScriptModule, or trace it as part of some larger function, but still: is it possible to get a static, traced version of an LSTM by itself (or of any RNN, for that matter) that accepts the PackedSequence inputs produced by pack_padded_sequence?
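For comparison, here is a sketch of the "trace as part of some larger function" route mentioned above. `PackedLSTM` is a hypothetical wrapper, not a PyTorch class: its `forward` takes the padded tensor and a lengths tensor (both plain Tensors, which the tracer accepts), packs inside, runs the LSTM and unpacks the result. It does not make a bare traced `nn.LSTM` accept a PackedSequence, and the checks only compare traced and eager outputs on the example input. If tracing prints a TracerWarning about converting a tensor to a Python value, that value has been baked into the trace as a constant, so test the traced module on other batch sizes and lengths before relying on it.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class PackedLSTM(nn.Module):
    # Hypothetical wrapper: the traced graph only sees plain tensors,
    # and packing/unpacking happens inside forward().
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, padded, lengths):
        # lengths: 1-D int64 CPU tensor sorted in descending order, as
        # required by pack_padded_sequence's default enforce_sorted=True.
        packed = pack_padded_sequence(padded, lengths, batch_first=True)
        packed_out, (h_n, _) = self.lstm(packed)
        # Back to a padded tensor of shape (batch, max(lengths), hidden_size).
        out, _ = pad_packed_sequence(packed_out, batch_first=True)
        return out, h_n


torch.manual_seed(0)
model = PackedLSTM(30, 25)
x = torch.randn(8, 10, 30)                         # 8 examples, 10 time steps, 30 features
lengths = torch.tensor([10, 8, 8, 6, 5, 5, 3, 1])  # descending, int64 on CPU

# Example inputs: a tuple of plain tensors, one per forward() argument.
traced = torch.jit.trace(model, (x, lengths))

out_eager, h_eager = model(x, lengths)
out_traced, h_traced = traced(x, lengths)
print(out_traced.shape)  # torch.Size([8, 10, 25])
```

If the lengths are not sorted in advance, `pack_padded_sequence` also accepts `enforce_sorted=False`, at the cost of an extra sort inside the traced graph.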