I am trying to export my PyTorch model to ONNX, but I keep getting the error:

RuntimeError: ONNX export failed: Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.

The error points at the pad_packed_sequence call where I construct a PackedSequence manually, but I am not entirely sure what I am doing wrong. Thanks for any help. My code is posted below with the corresponding tensor dimensions.
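For context, I invoke the export roughly like the sketch below (the dummy shapes, constructor arguments, and file name here are placeholders, not my exact values):

import torch

model = WordAttention(word_rnn_size=50, word_rnn_layers=2, word_att_size=100, dropout=0.3)
model.eval()

# Dummy inputs matching forward(documents, sentences_per_document, words_per_sentence);
# the shapes are illustrative: (batch, sentences, words), (batch, 1), (batch, sentences).
documents = torch.randint(0, 1000, (1, 10, 100))
sentences_per_document = torch.tensor([[10]])
words_per_sentence = torch.randint(1, 100, (1, 10))

torch.onnx.export(model,
                  (documents, sentences_per_document, words_per_sentence),
                  "word_attention.onnx",
                  opset_version=11)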
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence


class WordAttention(nn.Module):
    def __init__(self, word_rnn_size, word_rnn_layers, word_att_size, dropout):
        super(WordAttention, self).__init__()
        emb_dim = 100
        vocab_len = 1000000
        print("vocab size: " + str(vocab_len))
        self.MAXLEN = 100

        # Embedding layer initialised from a (vocab_len, emb_dim) matrix and frozen.
        # Note: the first argument must be vocab_len so it matches the weight shape.
        embedding = nn.Embedding(vocab_len, emb_dim)
        emb_matrix = np.random.rand(vocab_len, emb_dim)
        et = torch.tensor(emb_matrix, dtype=torch.float32)
        embedding.weight = nn.Parameter(et, requires_grad=False)
        self.embeddings = embedding

        # Bidirectional word-level RNN
        self.word_rnn = nn.GRU(emb_dim, word_rnn_size, num_layers=word_rnn_layers,
                               bidirectional=True, dropout=dropout, batch_first=True)

        # Attention projections
        self.word_attention = nn.Linear(2 * word_rnn_size, word_att_size)
        self.word_context_vector = nn.Linear(word_att_size, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
    def forward(self, documents, sentences_per_document, words_per_sentence):
        sentences_per_document = sentences_per_document.squeeze(dim=1)
        sentences = documents.squeeze()
        words_per_sentence = words_per_sentence.squeeze()

        sentences = self.dropout(self.embeddings(sentences))  # (n_sentences, word_pad_len, emb_size)

        packed_words = pack_padded_sequence(sentences,
                                            lengths=words_per_sentence.tolist(),
                                            batch_first=True,
                                            enforce_sorted=False)  # (n_words, word_emb)

        packed_words, _ = self.word_rnn(packed_words)  # (n_words, 2 * word_rnn_size)

        # Attention over the flattened (packed) word representations
        att_w = self.word_attention(packed_words.data)  # (n_words, att_size)
        att_w = torch.tanh(att_w)  # (n_words, att_size)
        att_w = self.word_context_vector(att_w).squeeze(1)  # (n_words)

        max_value = att_w.max()  # scalar, for numerical stability during exponent calculation
        att_w = torch.exp(att_w - max_value)  # (n_words)

        # This is the call the exporter complains about: the PackedSequence is
        # built by hand rather than coming straight from pack_padded_sequence.
        att_w, _ = pad_packed_sequence(PackedSequence(data=att_w,
                                                      batch_sizes=packed_words.batch_sizes,
                                                      sorted_indices=packed_words.sorted_indices,
                                                      unsorted_indices=packed_words.unsorted_indices),
                                       batch_first=True,
                                       total_length=100)  # (n_sentences, max(words_per_sentence))

        word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True)  # (n_sentences, max(words_per_sentence))

        sentences, _ = pad_packed_sequence(packed_words,
                                           batch_first=True,
                                           total_length=100)  # (n_sentences, max(words_per_sentence), 2 * word_rnn_size)
        print(sentences.size())

        # Find sentence embeddings
        sentences = sentences * word_alphas.unsqueeze(2)  # (n_sentences, max(words_per_sentence), 2 * word_rnn_size)
        sentences = sentences.sum(dim=1)  # (n_sentences, 2 * word_rnn_size)

        return sentences, word_alphas
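My understanding of the error message is that the exporter only accepts pack_padded_sequence and pad_packed_sequence when the padded output comes directly from the packed RNN output, with no hand-built PackedSequence in between. A minimal pattern that I believe the exporter can pair up looks like this (an illustrative sketch, not my actual model):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class PairedPackExample(nn.Module):
    # Sketch of the pack -> RNN -> unpack pattern, with no manually
    # constructed PackedSequence between the two calls.
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(8, 16, batch_first=True, bidirectional=True)

    def forward(self, x, lengths):
        packed = pack_padded_sequence(x, lengths.tolist(), batch_first=True,
                                      enforce_sorted=False)
        packed, _ = self.rnn(packed)
        out, _ = pad_packed_sequence(packed, batch_first=True, total_length=x.size(1))
        return out

# e.g. PairedPackExample()(torch.randn(2, 5, 8), torch.tensor([5, 3]))

In my model, the second pad_packed_sequence call receives a PackedSequence that I assemble myself from att_w, so the pack/pad calls no longer line up one-to-one, and I suspect that is what trips the exporter. What I don't know is how to rewrite the attention step so the export succeeds.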