Hi, I would like to use PackedSequence directly in a custom module for an NLP task. Specifically, given a padded batch, I want to convert it to a packed sequence, perform some operations on the data,
and convert it back to a padded sequence.
An example recipe is below (modified from the example at https://pytorch.org/docs/stable/jit.html).
My recipe works in Python, but when I use `torch.jit.script`, it fails with:
ValueError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults
I would like to understand why PackedSequence fails under TorchScript when I use it directly, as in the example below, even though it works fine when used through the wrapper functions in `torch.nn.utils.rnn`. I would also like to know how to fix it.
I will be grateful for any help in understanding this behavior.
Thank you!
import torch
from torch.nn.utils.rnn import PackedSequence
import torch.nn._VF as torch_varfuncs
from torch._jit_internal import Optional
class MyModule(torch.nn.Module):
    """Pack a padded batch, transform the packed data, then pad it again.

    ``forward`` flattens the batch with ``pack_padded_sequence`` and multiplies
    ``weight`` (shape ``(N, M)``) by the packed data. It feeds the resulting
    ``N``-vector through ``linear`` (``N -> M`` features) and re-pads the
    output with the original ``batch_sizes``.

    Because ``Tensor.mv`` is used, the packed data must be 1-D with ``M``
    elements. The re-pad also reuses ``batch_sizes``, so ``M`` must equal
    ``sum(lengths)``. This presumably means ``input`` is a 2-D
    ``(batch, time)`` tensor (TODO confirm for other callers).

    Args:
        N: rows of ``weight`` and ``in_features`` of ``linear``.
        M: columns of ``weight`` and ``out_features`` of ``linear``.
    """

    def __init__(self, N, M):
        super().__init__()
        # Registered as a Parameter, so scripting copies it into the ScriptModule.
        self.weight = torch.nn.Parameter(torch.rand(N, M))
        # Submodule; it is compiled together with this module when scripted.
        self.linear = torch.nn.Linear(N, M)

    def _pad(self, data, batch_first, batch_sizes, pad_value, sorted_indices, unsorted_indices):
        """Rebuild a PackedSequence from its fields and unpack it to padded form.

        Returns the ``(padded, lengths)`` tuple produced by
        ``pad_packed_sequence``, with ``pad_value`` in the padding slots.

        NOTE(review): this constructs ``PackedSequence`` directly. Scripting
        this module is reported to fail, and this construction is the likely
        culprit (confirm).
        """
        packed = torch.nn.utils.rnn.PackedSequence(
            data, batch_sizes, sorted_indices, unsorted_indices
        )
        return torch.nn.utils.rnn.pad_packed_sequence(
            packed, batch_first=batch_first, padding_value=pad_value
        )

    def forward(self, input, data_lengths):
        """Run the pack -> transform -> pad pipeline on a batch-first padded ``input``.

        Args:
            input: padded batch with the batch dimension first.
            data_lengths: valid length of each sequence in ``input``.

        Returns:
            ``(padded_output, lengths)`` from ``_pad``, padded with ``-1.0``.
        """
        batch_first = True
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            input,
            lengths=data_lengths,
            batch_first=batch_first,
            enforce_sorted=False,  # lengths need not be pre-sorted
        )
        projected = self.weight.mv(packed.data)
        # Calling the Linear submodule; scripting compiles it as part of this module.
        projected = self.linear(projected)
        return self._pad(
            projected,
            batch_first,
            packed.batch_sizes,
            -1.0,
            packed.sorted_indices,
            packed.unsorted_indices,
        )
class MyModuleVF(MyModule):
    """``MyModule`` variant whose ``_pad`` never constructs a ``PackedSequence``.

    The padding is done by calling ``torch_varfuncs._pad_packed_sequence``
    directly on the packed data and batch sizes. This keeps the re-pad step
    free of ``PackedSequence`` so that it is intended to compile under
    ``torch.jit.script``.
    """

    def _pad(self, data, batch_first1: bool, batch_sizes, pad_value: float,
             sorted_indices: Optional[torch.Tensor],
             unsorted_indices: Optional[torch.Tensor]):
        """Unpack packed ``data`` into a padded tensor in the original batch order.

        Args:
            data: packed data (the ``.data`` field of a packed sequence).
            batch_first1: if True, lay the result out batch-first; otherwise time-first.
            batch_sizes: per-time-step batch sizes from packing.
            pad_value: value written into padding positions. This was
                previously ignored in favour of a hard-coded ``-1.0``.
            sorted_indices: permutation applied when packing, or None.
            unsorted_indices: inverse of ``sorted_indices``, or None. If it is
                None but ``sorted_indices`` is given, the inverse is computed here.

        Returns:
            ``(padded_output, lengths)``. Both are re-ordered back to the
            original batch order when permutation info is available.
        """
        # The number of time steps equals the number of entries in batch_sizes.
        max_length = batch_sizes.size(0)
        padded_output, lengths = torch_varfuncs._pad_packed_sequence(
            data, batch_sizes, batch_first1, pad_value, max_length)

        restore_order = unsorted_indices
        if restore_order is None:
            if sorted_indices is not None:
                # Invert the sort permutation: inverse[sorted_indices[i]] = i.
                # The arange dtype is pinned to the index dtype so scatter_'s
                # src matches self (earlier JIT versions produced float here).
                inverse = torch.empty_like(sorted_indices)
                inverse.scatter_(
                    0,
                    sorted_indices,
                    torch.arange(sorted_indices.numel(),
                                 dtype=sorted_indices.dtype,
                                 device=sorted_indices.device))
                restore_order = inverse

        if restore_order is not None:
            batch_dim = 0 if batch_first1 else 1
            # `lengths` may live on a different device than the permutation
            # indices; move the index so the gather does not mix devices.
            return (padded_output.index_select(batch_dim, restore_order),
                    lengths[restore_order.to(lengths.device)])
        return padded_output, lengths
# Repro: a padded batch of three sequences, with -1.0 as the padding value.
test_input = torch.tensor(
    [[1., 2., 3., 4.], [5., 6., -1.0, -1.0], [8, 9, 10, -1.0]], dtype=torch.float)
data_lengths = torch.tensor([4, 2, 3])
# The modules re-pad linear's output with the packed batch_sizes, so M (the
# out_features) must equal the number of packed steps, i.e. sum(data_lengths).
# Counting positive entries only matched this because every real value here
# happens to be positive.
size_ = data_lengths.sum()
num_steps = size_.item()

# works: eager MyModule
mm = MyModule(20, num_steps)
result = mm(test_input, data_lengths)
# works: eager MyModuleVF
mmvf = MyModuleVF(20, num_steps)
result_vf = mmvf(test_input, data_lengths)
# works: scripted MyModuleVF
mmvf_s = torch.jit.script(MyModuleVF(20, num_steps))
result_vf_s = mmvf_s(test_input, data_lengths)
# does not work: scripting MyModule is reported to raise the ValueError above
mm_s = torch.jit.script(MyModule(20, num_steps))
# Call the scripted module (this previously called the eager `mm` by mistake).
result_s = mm_s(test_input, data_lengths)