[Updated] Problem enumerating nn.ModuleList with torch.jit



I successfully traced the model, but it looks like I cannot save it.
I thought that once I do something like model = trace(model, (params)), the model is ready for saving. Am I wrong?
The torch version is ‘1.0.0.dev20190130’.
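For reference, a trace-then-save round trip does work when everything the model runs is plain tensor code. The sketch below assumes a recent PyTorch release (not the 1.0 nightly from the question) and uses a hypothetical TinyEncoder model. It traces the model with one example input, saves it with torch.jit.save to an in-memory buffer instead of a file, and loads it back.

```python
import io

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical model: only tensor operations, no Python-side calls."""

    def __init__(self, in_features: int, hidden: int):
        super().__init__()
        self.proj = nn.Linear(in_features, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


def trace_save_load(model: nn.Module, example: torch.Tensor) -> torch.jit.ScriptModule:
    """Trace `model` on one example input, serialize it, and load it back."""
    traced = torch.jit.trace(model, (example,))  # records the ops run on `example`
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)  # like traced.save("model.pt"), but in memory
    buffer.seek(0)
    return torch.jit.load(buffer)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyEncoder(8, 4).eval()
    example = torch.randn(2, 8)
    restored = trace_save_load(model, example)
    # The reloaded module should reproduce the eager model's output.
    print(torch.allclose(restored(example), model(example)))
```

Keep in mind that tracing records only the operations executed for the example input, so Python control flow that depends on the data is baked in; scripting (torch.jit.script) compiles such control flow instead.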

Here is the traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-23-4014fec4f00b> in <module>
----> 1 ctc_model.save("ctc_test.ph")

could not export python function call <python_value>. Remove calls to python functions before export.:
def forward(self, x, x_length):
    h_t, x_length = self.rnn(x, x_length)
                    ~~~~~~~~ <--- HERE

(Michael Suo) #2

You cannot export a model if it contains a Python function call. Is self.rnn() a Python function?
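A minimal sketch of the difference, assuming a recent PyTorch release: torch.jit.ignore (a newer decorator, not something used in this thread) deliberately leaves a method as a Python call, which reproduces the same kind of export failure. The classes UsesPythonCall and PureScript and the helper can_export are hypothetical names for illustration. Both modules compile with torch.jit.script, but only the one without the Python call is expected to save.

```python
import io

import torch
import torch.nn as nn


class UsesPythonCall(nn.Module):
    """Hypothetical module whose forward dispatches to a Python method."""

    @torch.jit.ignore
    def python_helper(self, x: torch.Tensor) -> torch.Tensor:
        # Not compiled: at run time TorchScript calls back into Python here.
        return x * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.python_helper(x) + 1


class PureScript(nn.Module):
    """Hypothetical module where every method it calls is compiled."""

    def helper(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.helper(x) + 1


def can_export(module: nn.Module) -> bool:
    """Script `module`, then report whether it can be serialized."""
    scripted = torch.jit.script(module)  # compiling succeeds for both classes
    try:
        torch.jit.save(scripted, io.BytesIO())
        return True
    except RuntimeError:
        # Saving is expected to fail with an error about exporting a Python function call.
        return False


if __name__ == "__main__":
    print(can_export(PureScript()), can_export(UsesPythonCall()))
```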

(Ananth KV) #3

I think if self.rnn is a subclass of torch.jit.ScriptModule, then it should be possible to trace and save it.
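In more recent releases the same idea applies through recursive scripting: torch.jit.script compiles a plain nn.Module together with every submodule its forward calls, so the self.rnn call no longer has to go back to Python. The sketch below assumes a recent PyTorch; Encoder, CTCModel and script_and_reload are hypothetical stand-ins shaped after the traceback above, not the poster's actual model.

```python
import io
from typing import Tuple

import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical stand-in for the thread's `self.rnn` submodule."""

    def __init__(self, input_size: int, hidden: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden, bidirectional=True)

    def forward(self, x: torch.Tensor, x_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        out, _ = self.lstm(x)  # x: (seq_len, batch, input_size)
        return out, x_length


class CTCModel(nn.Module):
    """Hypothetical top-level model whose forward mirrors the traceback."""

    def __init__(self, input_size: int, hidden: int):
        super().__init__()
        self.rnn = Encoder(input_size, hidden)

    def forward(self, x: torch.Tensor, x_length: torch.Tensor) -> torch.Tensor:
        h_t, x_length = self.rnn(x, x_length)
        return h_t


def script_and_reload(model: nn.Module) -> torch.jit.ScriptModule:
    """Compile `model` and its submodules, save to memory, and load it back."""
    scripted = torch.jit.script(model)  # self.rnn is compiled too, no Python call left
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = CTCModel(8, 4).eval()
    x = torch.randn(5, 2, 8)  # (seq_len, batch, features)
    lengths = torch.tensor([5, 5])
    print(script_and_reload(model)(x, lengths).shape)
```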

(Michael Suo) #4

Could you provide a script/model we can use to reproduce the problem? It’s hard to say what’s going on here without more information. Thanks!


Thank you very much for your replies!
Actually, I think I fixed my original problem: self.rnn was an nn.Module, and now that I have made it a ScriptModule as well, I have a new problem. It looks like I cannot loop over an nn.ModuleList. I also tried indexing it inside the loop, but that did not work either. Is it even possible to use the JIT with nn.ModuleList?

I omitted some parts for brevity:

import torch
import torch.nn as nn
from torch.jit import ScriptModule, script_method

class PyramidalRNNENcoder(ScriptModule):

    __constants__ = ['num_layers']

    def __init__(self, num_mels, encoder_size, num_layers, downsampling=None, dropout=0.0):
        super(PyramidalRNNENcoder, self).__init__()
        self.rnns = nn.ModuleList()
        for i in range(num_layers):
            input_size = num_mels*2 if i == 0 else encoder_size*2
            lstm_i = nn.LSTM(input_size,
                             hidden_size=encoder_size, bidirectional=True)
            self.rnns.append(lstm_i)  # register each layer in the ModuleList
        self.num_layers = num_layers

    @script_method
    def forward(self, x, x_length):
        batch_size = x.size(0)
        idx = 0
        # (other code omitted for brevity)
        for rnn in self.rnns:  # <-- compilation fails on this line
            rnn_result = rnn(data)

The for loop over self.rnns fails to compile with:
RuntimeError: python value of type 'ModuleList' cannot be used as a tuple:

(Michael Suo) #6

Try adding "rnns" to the __constants__ attribute. We should work on a better error message here.
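The __constants__ suggestion targets the 1.0-era ScriptModule API used in this thread. As a hedged sketch of the same pattern in later releases, assuming a recent PyTorch where torch.jit.script accepts a plain nn.Module: a for loop over an nn.ModuleList is unrolled at compile time, so it should compile without any __constants__ entry. PyramidalEncoderSketch and export_encoder are hypothetical; the class is a simplified version of the encoder above that omits the frame stacking and the downsampling and dropout arguments, so its first layer takes num_mels features rather than num_mels*2.

```python
import io

import torch
import torch.nn as nn


class PyramidalEncoderSketch(nn.Module):
    """Hypothetical, simplified stack of bidirectional LSTMs kept in a ModuleList."""

    def __init__(self, num_mels: int, encoder_size: int, num_layers: int):
        super().__init__()
        self.rnns = nn.ModuleList()
        for i in range(num_layers):
            # Layers after the first consume both directions of the previous layer.
            input_size = num_mels if i == 0 else encoder_size * 2
            self.rnns.append(nn.LSTM(input_size, hidden_size=encoder_size, bidirectional=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TorchScript unrolls this loop, emitting one call per LSTM in the list.
        for rnn in self.rnns:
            x, _ = rnn(x)
        return x


def export_encoder(encoder: nn.Module) -> torch.jit.ScriptModule:
    """Script the encoder, save it to memory, and load it back."""
    scripted = torch.jit.script(encoder)  # no __constants__ entry for self.rnns
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


if __name__ == "__main__":
    torch.manual_seed(0)
    encoder = PyramidalEncoderSketch(num_mels=10, encoder_size=6, num_layers=3).eval()
    x = torch.randn(7, 2, 10)  # (seq_len, batch, num_mels)
    loaded = export_encoder(encoder)
    print(loaded(x).shape)
```

Indexing the list with a loop variable (self.rnns[i]) is a different case from iterating it directly; the sketch only relies on direct iteration.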