Jit.script and custom usage of PackedSequence

Hi, I would like to use PackedSequence directly in a custom module for nlp task. Specifically, given a padded batch, i want to convert it to a packed sequence, perform some operations on the data and convert back to padded sequence.

An example recipe is below (modifed the example from https://pytorch.org/docs/stable/jit.html)

my recipe works in python but when using jit.script, it fails with

ValueError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults

I wanted to learn why PackedSequence fails with torch script when used here but works fine when called using wrapper methods in torch.nn.utils.rnn. And how I can fix it.

I will be grateful for any help in understanding this behavior.

Thank you!

import torch
from torch.nn.utils.rnn import PackedSequence
import torch.nn._VF as torch_varfuncs
from torch._jit_internal import Optional

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def _pad(self, data, batch_first, batch_sizes, pad_value, sorted_indices, unsorted_indices):
        packed_seq = torch.nn.utils.rnn.PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
        return torch.nn.utils.rnn.pad_packed_sequence(packed_seq, batch_first, pad_value)

    def forward(self, input, data_lengths):
        batch_first = True
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(input, batch_first=batch_first,lengths=data_lengths, enforce_sorted=False)
        output = self.weight.mv(packed_input.data)
        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return self._pad(output, batch_first, packed_input.batch_sizes,-1.0, packed_input.sorted_indices, packed_input.unsorted_indices)

class MyModuleVF(MyModule):

    def _pad(self, data, batch_first1: bool, batch_sizes, pad_value: float, sorted_indices: Optional[torch.Tensor], unsorted_indices: Optional[torch.Tensor]):

        max_length = batch_sizes.size(0)
        padded_output, lengths = torch_varfuncs._pad_packed_sequence(data, batch_sizes, batch_first1, -1.0, max_length)
        if sorted_indices is not None:
           # had to invert permute specifically as pytorch method was giving errors in jit (arange is returning float type and not long, as expected)
            output = torch.empty_like(sorted_indices)
            output.scatter_(0, sorted_indices,torch.arange(0, sorted_indices.numel(), device=sorted_indices.device).long())
            batch_dim = 0 if batch_first1 else 1
            return padded_output.index_select(batch_dim, output), lengths[output]
        return padded_output, lengths

test_input = torch.tensor([[1., 2., 3., 4.], [5., 6., -1.0, -1.0],[8, 9, 10, -1.0]], dtype=torch.float)
data_lengths = torch.tensor([4,2,3])
size_ = (test_input > 0).sum()
# works
mm = MyModule(20,size_.item())
result = mm(test_input, data_lengths)


# works
mmvf = MyModuleVF(20,size_.item())
result_vf = mmvf(test_input, data_lengths)


# works
mmvf_s = torch.jit.script(MyModuleVF(20,size_.item()))
result_vf_s = mmvf_s(test_input, data_lengths)

# does not work
mm_s = torch.jit.script(MyModule(20,size_.item()))
result_s = mm(test_input, data_lengths)

Thanks for the repro! There is a bug somewhere here, would you mind filing an issue on GitHub?

For some reason it works for me on master if you use PackedSequence directly instead of the qualified version, so that could be a workaround until we get this fixed:

    def _pad(self, data, batch_first: bool, batch_sizes, pad_value: float, sorted_indices: Optional[torch.Tensor], unsorted_indices: Optional[torch.Tensor]):
        packed_seq = PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
        return torch.nn.utils.rnn.pad_packed_sequence(packed_seq, batch_first, pad_value)