I am trying to use the JIT to get a traced version of an LSTM.
Here is some example code:
import torch
import torch.nn as nn
a = torch.randn(8, 5, 30) # batch of 8 examples, 5 time steps, 30 features
b = torch.randn(16, 10, 30) # batch of 16 examples, 10 time steps, 30 features
lstm = torch.jit.trace(nn.LSTM(30, 25, batch_first=True), a)
lstm(a)
lstm(b)
As expected, the above works fine. However, I want to use pack_padded_sequence:
lengths = torch.randint(low=1, high=b.shape[1], size=(len(b),))
lengths, _ = torch.sort(lengths, descending=True)
b = nn.utils.rnn.pack_padded_sequence(b, lengths, batch_first=True)
This doesn’t work:
>>> lstm(b)
RuntimeError: forward() expected value of type Tensor for argument 'input' in position 0, but instead got value of type PackedSequence.
Okay, the traced LSTM expects a Tensor, not a PackedSequence. But tracing the LSTM with a packed-sequence input also doesn't work:
lstm = torch.jit.trace(nn.LSTM(30, 25, batch_first=True), b)
RuntimeError: input must have 3 dimensions, got 2
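My guess (unverified) is that trace flattens the PackedSequence, which is a NamedTuple, into its constituent tensors, so the LSTM ends up receiving the 2-D .data tensor directly. A minimal check that the packed data really is 2-D:

```python
import torch
import torch.nn as nn

b = torch.randn(16, 10, 30)  # batch of 16 examples, 10 time steps, 30 features
lengths, _ = torch.sort(
    torch.randint(low=1, high=b.shape[1], size=(len(b),)), descending=True
)
packed = nn.utils.rnn.pack_padded_sequence(b, lengths, batch_first=True)

# The packed data is a single 2-D tensor of shape (sum(lengths), features),
# which would explain the "got 2" in the error above.
print(packed.data.dim())   # 2
print(packed.data.shape)
```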
I realise I could probably put it into a ScriptModule, or trace it as part of some larger function, but still: is it possible to get a static, traced version of an LSTM by itself (or any RNN, for that matter) that can take pack_padded_sequence inputs?
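For reference, the "larger function" workaround I'm alluding to might look like the sketch below: a wrapper module (PackedLSTM is just a name I made up) whose traced inputs and outputs are plain tensors, with the packing and unpacking done inside the traced forward. This assumes pack_padded_sequence/pad_packed_sequence trace cleanly, which I haven't confirmed in general:

```python
import torch
import torch.nn as nn

class PackedLSTM(nn.Module):
    """Hypothetical wrapper: pack inside forward so trace only sees tensors."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x, lengths):
        # lengths must be sorted in descending order for pack_padded_sequence
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
        out, _ = self.lstm(packed)
        # Unpack back to a padded (batch, max_len, hidden) tensor
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        return out

b = torch.randn(16, 10, 30)
lengths, _ = torch.sort(
    torch.randint(low=1, high=b.shape[1], size=(len(b),)), descending=True
)
traced = torch.jit.trace(PackedLSTM(30, 25), (b, lengths))
out = traced(b, lengths)  # (16, lengths.max(), 25)
```

One caveat: tracing records concrete values from the example inputs, so shape-dependent quantities (e.g. the maximum length) may get baked into the graph rather than staying dynamic.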