Can't save the LSTM model with torch.jit.script_method

class BidirectionalLSTM(torch.jit.ScriptModule):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Args:
        nIn: number of input features per timestep.
        nHidden: hidden size of each LSTM direction.
        nOut: number of output features per timestep.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # The forward and backward directions are concatenated, so the
        # projection sees 2 * nHidden features per timestep.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    @torch.jit.script_method
    def forward(self, input):
        """Run the BiLSTM and project every timestep to ``nOut`` features.

        ``nn.LSTM`` is built with its default ``batch_first=False``, so
        ``input`` is expected as (seq_len, batch, nIn); the result has
        shape (seq_len, batch, nOut).
        """
        seq_out, _state = self.rnn(input)
        seq_len, batch, feat = seq_out.size()
        # Merge time and batch axes so the Linear layer gets a 2-D matrix.
        flat = seq_out.view(seq_len * batch, feat)
        projected = self.embedding(flat)
        return projected.view(seq_len, batch, -1)

# Build the scripted BiLSTM and serialize it with ScriptModule.save.
net = BidirectionalLSTM(nIn=256, nHidden=256, nOut=512)
net.save('model.pt')

I got the following error:

Traceback (most recent call last):
  File "/home/xxh/Desktop/crnn/models/test.py", line 24, in <module>
    net.save('model.pt')
RuntimeError: 
could not export python function call <python_value>. Remove calls to python functions before export.:
@torch.jit.script_method
def forward(self, input):
    recurrent, _ = self.rnn(input)
                   ~~~~~~~~ <--- HERE
    T, b, h = recurrent.size()
    t_rec = recurrent.view(T * b, h)

    output = self.embedding(t_rec)  # [T * b, nOut]
    output = output.view(T, b, -1)

    return output

I think the problem may be that nn.LSTM and nn.Linear are not traced.

From the docs:

To be able to save a module, it must not make any calls to native python functions. This means that all submodules must be subclasses of ScriptModules as well.

So you have to trace the inner modules, as is done for Conv2d in the documentation example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace

class MyScriptModule(ScriptModule):
    """Two traced 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # torch.jit.trace turns each nn.Conv2d into a ScriptModule by running
        # it once on an example input of the given shape. Each module is
        # built before its example tensor is drawn.
        first_conv = nn.Conv2d(1, 20, 5)
        self.conv1 = trace(first_conv, torch.rand(1, 1, 16, 16))
        second_conv = nn.Conv2d(20, 20, 5)
        self.conv2 = trace(second_conv, torch.rand(1, 20, 16, 16))

    @script_method
    def forward(self, input):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``input``."""
        hidden = F.relu(self.conv1(input))
        return F.relu(self.conv2(hidden))

@acobobby You are correct, but we have some special infrastructure to support torch.nn modules without needing to trace them (see "Builtin Functions" in the master docs for details). So simply assigning them as submodules without tracing them is fine (e.g. `self.conv1 = nn.Conv2d(1, 20, 5)`).

I don’t remember whether these changes have made it into an official release yet, though. @XiaXuehai, could you try to reproduce your issue on the PyTorch nightly build? Your code snippet runs fine for me on it.

1 Like

@driazati Thanks, it works on the PyTorch nightly build.

Another question: how do I load a state_dict into a module that uses torch.jit.script_method? The layer names have changed, and loading with strict=False does not load the weights correctly!

Solved: I renamed the pretrained model’s layers one by one.