I am trying to use the JIT to get a traced version of an LSTM.
Here is some example code:
```python
import torch
import torch.nn as nn

a = torch.randn(8, 5, 30)    # batch of 8 examples, 5 time steps, 30 features
b = torch.randn(16, 10, 30)  # batch of 16 examples, 10 time steps, 30 features

lstm = torch.jit.trace(nn.LSTM(30, 25, batch_first=True), a)
lstm(a)
lstm(b)
```
As expected, the above works fine. However, I want to feed it a packed sequence instead:
```python
lengths = torch.randint(low=1, high=b.shape[1] + 1, size=(len(b),))  # lengths in [1, 10]
lengths, _ = torch.sort(lengths, descending=True)
b = nn.utils.rnn.pack_padded_sequence(b, lengths, batch_first=True)
```
This doesn’t work:
```python
>>> lstm(b)
RuntimeError: forward() expected value of type Tensor for argument 'input' in position 0, but instead got value of type PackedSequence.
```
Okay, the traced LSTM expects a Tensor, not a PackedSequence. But tracing the LSTM with a packed sequence as the example input also doesn't work:
```python
>>> lstm = torch.jit.trace(nn.LSTM(30, 25, batch_first=True), b)
RuntimeError: input must have 3 dimensions, got 2
```
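For what it's worth, the "2 dimensions" error seems to come from the tracer unpacking the PackedSequence (it is essentially a tuple of tensors) and handing its 2-D `.data` payload to `forward` as if it were the input tensor:

```python
import torch
import torch.nn as nn

b = torch.randn(16, 10, 30)
lengths, _ = torch.sort(torch.randint(low=1, high=11, size=(16,)), descending=True)
packed = nn.utils.rnn.pack_padded_sequence(b, lengths, batch_first=True)

# The packed payload is a single 2-D tensor of shape (sum(lengths), features);
# the per-step batch sizes live separately in packed.batch_sizes.
print(packed.data.shape)  # torch.Size([sum(lengths), 30])
print(packed.data.dim())  # 2
```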
I realise I could probably put it into a ScriptModule, or trace it as part of some larger function, but still: is it possible to get a static, traced version of an LSTM by itself (or any RNN, for that matter) which can take a PackedSequence as input?