Hi,
Currently I get this error when using `chunk` in `forward`; it appears after the first iteration. But if `chunk` is replaced with `split`, the error does not happen. Is there a fix for this problem?
Is `split` the same as `chunk` in terms of the backward pass?
This is the minimal code with which I can reproduce it:
import torch
from torch import jit
from torch.nn import Parameter
class CustomRNNCell(jit.ScriptModule):
    """TorchScript recurrent cell with three gates and an (h, c) state pair.

    The input and the previous hidden state are each projected to
    ``3 * hidden_size`` features; the sum is cut into three equal slices
    along dim 1 which are combined with the previous cell state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # One row block per gate, stacked along dim 0.
        gate_rows = 3 * hidden_size
        self.weight_ih = Parameter(torch.randn(gate_rows, input_size))
        self.weight_hh = Parameter(torch.randn(gate_rows, hidden_size))
        self.bias_ih = Parameter(torch.randn(gate_rows))
        self.bias_hh = Parameter(torch.randn(gate_rows))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        """Run one step and return ``(hy, (hy, cy))``."""
        prev_h, prev_c = state

        input_proj = torch.mm(input, self.weight_ih.t())
        hidden_proj = torch.mm(prev_h, self.weight_hh.t())
        gates = input_proj + self.bias_ih + hidden_proj + self.bias_hh

        # Three equal slices along the feature dimension.
        gate_m, gate_o, gate_i = gates.chunk(3, 1)
        gate_m = torch.sigmoid(gate_m)
        gate_o = torch.tanh(gate_o)
        gate_i = torch.tanh(gate_i)

        cy = (1 - gate_m) * prev_c + (gate_m * gate_i)
        hy = (1 - gate_o) * gate_i + (gate_o * prev_c)
        return hy, (hy, cy)
# Anomaly mode makes autograd report the forward call behind a failing backward op.
torch.autograd.set_detect_anomaly(mode=True)

batch_size, input_size, hidden_size = 8, 1280, 256
cell = CustomRNNCell(input_size=input_size, hidden_size=hidden_size)

# Each step uses a fresh random input and an all-zero (h, c) state; only the
# first output of the cell feeds the loss.
for step in range(20):
    x = torch.randn(batch_size, input_size)
    state = (
        torch.zeros(batch_size, hidden_size),
        torch.zeros(batch_size, hidden_size),
    )
    out, _ = cell(x, state)
    print(step)  # printed before backward, so the last number shown is the failing step
    out.mean().backward()
And the error message:
0
1
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py:147: UserWarning: Error detected in torch::jit::(anonymous namespace)::DifferentiableGraphBackward. Traceback of forward call that caused the error:
File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 845, in launch_instance
app.start()
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
self.io_loop.start()
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
self._run_once()
File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
handle._run()
File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
handler_func(fileobj, events)
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 451, in _handle_events
self._handle_recv()
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
self._run_callback(callback, msg)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 434, in _run_callback
callback(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
handler(stream, idents, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
if self.run_code(code, result):
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-23-7bc819a85eda>", line 47, in <module>
out, _ = cell(x, state)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
(Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-23-7bc819a85eda> in <module>()
48
49 print(i)
---> 50 out.mean().backward()
1 frames
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
145 Variable._execution_engine.run_backward(
146 tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 147 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
148
149
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: tensor does not have a device