Hi,
I created an issue at:
However, I am thinking maybe here is the right place to ask for help. Thus I copied my ask here:
This is regarding sample at:
https://pytorch.org/blog/the-road-to-1_0/:
from torch.jit import script
@script
def rnn_loop(x):
hidden = None
for x_t in x.split(1):
x, hidden = model(x, hidden)
return x
I cannot make it to work. Here is my code:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
def test_ScriptModelRNN():
class SimpleRNNCell(nn.Module):
def **init** (self, input_size, hidden_size):
super(SimpleRNNCell, self). **init** ()
self.linear_h = nn.Linear(input_size, hidden_size)
def forward(self, inp, h_0):
h = self.linear_h(inp)
return h + h_0, h
with torch.no_grad():
sequence_len, input_size, hidden_size = 4, 3, 2
model = SimpleRNNCell(input_size, hidden_size)
hidden = torch.zeros(1, hidden_size)
# # test cell
# cell_input = torch.randn(input_size)
# cell_output, hidden = model(cell_input, hidden)
# import pdb; pdb.set_trace()
# #
@torch.jit.script
def rnn_loop(x):
hidden = None
for x_t in x.split(1):
x, hidden = model(x_t, hidden)
return x
input = torch.randn(sequence_len, input_size)
output = rnn_loop(input)
I am getting:
Exception has occurred: RuntimeError
for operator (Tensor 0, Tensor 1) -> (Tensor, Tensor):
expected a value of type Tensor for argument '1' but found Tensor?
@torch.jit.script
def rnn_loop(x):
hidden = None
for x_t in x.split(1):
x, hidden = model(x_t, hidden)
~~~~~~ <--- HERE
return x
:
@torch.jit.script
def rnn_loop(x):
hidden = None
for x_t in x.split(1):
x, hidden = model(x_t, hidden)
~~~~~ <--- HERE
return x
File "/home/liqun/pytorch/torch/jit/ **init** .py", line 751, in script
_jit_script_compile(mod, ast, _rcb, get_default_args(obj))
File "/home/liqun/Untitled Folder/test_onnx_export.py", line 218, in test_ScriptModelRNN
@torch.jit.script
File "/home/liqun/Untitled Folder/test_onnx_export.py", line 282, in
test_ScriptModelRNN()
File "/home/liqun/.conda/envs/py36/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/liqun/.conda/envs/py36/lib/python3.6/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/home/liqun/.conda/envs/py36/lib/python3.6/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)