Torchscript error

Trying to modify the below code to torch.jit.script - but facing errors…any pointers?

class WrappedLSTM(nn.Module):
    def __init__(self, d_input, d_model, n_layer, batch_first, bidirectional, bias, dropout):
        super(WrappedLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=d_input, hidden_size=d_model, num_layers=n_layer, batch_first=batch_first,
                        bidirectional=bidirectional, bias=bias, dropout=dropout)

    #def forward(self, input, hx=None):
    #    return self.lstm(input, hx)
    @script_method
    def forward(self, input: torch.Tensor, hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        # Ensure to use the specific methods of self.lstm (e.g., lstm.forward) if needed
        return self.lstm(input, hx)
class Encoder(ScriptModule):
    def __init__(self, d_input, d_model, n_layer, unidirect=False,
                 dropout=0.2, dropconnect=0., time_ds=1, use_cnn=False, freq_kn=3, freq_std=2, pack=False):
        super().__init__()

        self.time_ds = time_ds

        if use_cnn:
            cnn = [nn.Conv2d(1, 32, kernel_size=(3, freq_kn), stride=(2, freq_std)), nn.ReLU(),
                   nn.Conv2d(32, 32, kernel_size=(3, freq_kn), stride=(2, freq_std)), nn.ReLU()]
            self.cnn = nn.Sequential(*cnn)
            d_input = ((((d_input - freq_kn) // freq_std + 1) - freq_kn) // freq_std + 1) * 32
        else:
            self.cnn = None

        # For TorchScript compatibility, use a ModuleList to store layers
        self.rnn = nn.ModuleList([WrappedLSTM(d_input=d_input, d_model=d_model, n_layer=n_layer, batch_first=True,
                                             bidirectional=(not unidirect), bias=True, dropout=dropout)])

        self.unidirect = unidirect
        self.pack = pack

    @script_method
    def rnn_fwd(self, seq, hid):
        # Note: WrappedLSTM should also be TorchScript-compatible
        seq, hid = self.rnn[0](seq, hid)
        return seq, hid

    @script_method
    def forward(self, seq, hid=None):
        if self.time_ds > 1:
            ds = self.time_ds
            l = ((seq.size(1) - 3) // ds) * ds
            seq = seq[:, :l, :]
            seq = seq.view(seq.size(0), -1, seq.size(2) * ds)

        if self.cnn is not None:
            seq = self.cnn(seq.unsqueeze(1))
            seq = seq.permute(0, 2, 1, 3).contiguous()
            seq = seq.view(seq.size(0), seq.size(1), -1)

        seq, hid = self.rnn_fwd(seq, hid)

        if not self.unidirect:
            hidden_size = seq.size(2) // 2
            seq = seq[:, :, :hidden_size] + seq[:, :, hidden_size:]

        return seq, hid

Error Message:

forward(torch.eval_ipex_cnn_true.WrappedLSTM self, Tensor input, (Tensor, Tensor)? hx=None) → ((Tensor, (Tensor, Tensor))):
Expected a value of type ‘Optional[Tuple[Tensor, Tensor]]’ for argument ‘hx’ but instead found type ‘Tensor (inferred)’.
Inferred the value for argument ‘hx’ to be of type ‘Tensor’ because it was not annotated with an explicit type.
:
File “/home/psakhamo/zoom_quantization/eval_ipex_cnn_true.py”, line 115
def rnn_fwd(self, seq, hid):
# Note: WrappedLSTM should also be TorchScript-compatible
seq, hid = self.rnn[0](seq, hid)
~~~~~~~~~~~ <— HERE
return seq, hid
‘Encoder.rnn_fwd’ is being compiled since it was called from ‘Encoder.forward’
File “/home/psakhamo/zoom_quantization/eval_ipex_cnn_true.py”, line 131
seq = seq.view(seq.size(0), seq.size(1), -1)

    seq, hid = self.rnn_fwd(seq, hid)
                                 ~~~ <--- HERE

    if not self.unidirect:

Able to get rid of above error - with below changes

    @script_method
    def rnn_fwd(self, seq,  hid:Optional[Tuple[torch.Tensor, torch.Tensor]]):
        # Note: WrappedLSTM should also be TorchScript-compatible
        seq, hid = self.rnn[0](seq, hid)
        return seq, hid

    @script_method
    def forward(self, seq,  hid:Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        if self.time_ds > 1:
            ds = self.time_ds
            l = ((seq.size(1) - 3) // ds) * ds
            seq = seq[:, :l, :]
            seq = seq.view(seq.size(0), -1, seq.size(2) * ds)

Seeing a new one

RuntimeError: Tried to serialize object torch.eval_ipex_cnn_true.Encoder which does not have a getstate method defined!

Added “getstate” implementation to Encoder class as below

    def __getstate__(self):
        state = self.__dict__.copy()
        # Remove the 'pack' key from the state dictionary
        state.pop('pack', None)
        return state

Now - when trying to export/serialize the pt mode

torch.save(model.encoder, torch_enc_model_path)
  • still getting below error

RuntimeError: Tried to serialize object torch.eval_ipex_cnn_true.Encoder which does not have a getstate method defined!