Error detected in torch::jit::(anonymous namespace)::DifferentiableGraphBackward

Hi,
Currently i have this error when using chunk on forward after first iteration. but if chunk is replace with split, the error not happen. is there any fix related these problem?

is split same as chunk on term of backward?

this is minimal code i can reproduce

import torch
from torch import jit
from torch.nn import Parameter

class CustomRNNCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(CustomRNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(3 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(3 * hidden_size))
        self.bias_hh = Parameter(torch.randn(3 * hidden_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        
        m, o, i = gates.chunk(3, 1)

        m = torch.sigmoid(m)
        o = torch.tanh(o)
        i = torch.tanh(i)        

        cy = (1 - m) * cx + (m * i)
        hy = (1 - o) * i + (o * cx)

        return hy, (hy, cy)

torch.autograd.set_detect_anomaly(mode=True)

cell = CustomRNNCell(
    input_size=1280,
    hidden_size=256
)

for i in range(20):
    x = torch.randn(8, 1280)
    state = (
        torch.zeros(8, 256),
        torch.zeros(8, 256)
    )

    out, _ = cell(x, state)
    
    print(i)
    out.mean().backward()

and the error message

0
1
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py:147: UserWarning: Error detected in torch::jit::(anonymous namespace)::DifferentiableGraphBackward. Traceback of forward call that caused the error:
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 845, in launch_instance
    app.start()
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
    self._run_once()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
    handle._run()
  File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
    handler_func(fileobj, events)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 451, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 434, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-23-7bc819a85eda>", line 47, in <module>
    out, _ = cell(x, state)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
 (Triggered internally at  /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-23-7bc819a85eda> in <module>()
     48 
     49     print(i)
---> 50     out.mean().backward()

1 frames
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    148 
    149 

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: tensor does not have a device

Anyone who can help?

This is a bug in the autodiff, I would recommend to file an issue on the PyTorch github (and crosslink here and the issue). As someone who sometimes looks into PyTorch issues, thank you for making a reproducing example. These are gold to anyone trying to fix things!

Best regards

Thomas

Hi Thomas, big thanks for response before.
Just want to confirm something, actually i want to implement this paper, the code above is just a test to produce bug, but it’s base of this forward code

    @jit.script_method
    def forward(self, x, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]        
        hx, cx = state

        xh = (
            torch.mm(x, self.weight_ih.t()) + self.bias_ih + 
            torch.mm(hx, self.weight_hh.t()) + self.bias_hh
        )        
        
        i, m, o = xh.chunk(3, 1)

        m = m + (self.weight_ch_m * cx)
        o = o + (self.weight_ch_o * cx)

        i = torch.tanh(i)
        m = torch.sigmoid(m)
        o = torch.sigmoid(o)        

        # Base on Formula
        h = (1 - m) * cx + (m * i)
        c = (1 - o) * i + (o * cx)       

        return h, (h, c)    

since the h will be h + (c * 0) to make grad connected to backward, is the implementation of this code is correct for the paper in term of forward and backward? or there is something wrong with my implementation?

Any response will be appreciate, Thanks.