Spooky action at a distance between model compilations

We have a strange problem with a (JIT) compiled model. We have two pytest unit tests for our model. In both tests the model is initialized with the same parameters, compiled via torch.jit.script(model), and fed with toy data, but only the second test uses the output to calculate a loss and eventually calls backward on it. If both tests are run, the second test throws:

>       Variable._execution_engine.run_backward(
            tensors, grad_tensors_, retain_graph, create_graph, inputs,
            allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
E       RuntimeError: The following operation failed in the TorchScript interpreter.
E       Traceback of TorchScript (most recent call last):
E         File "<string>", line 162, in <backward op>
E               def abs(self):
E                   def backward(grad_output):
E                       return grad_output * self.sign()
E                                            ~~~~~~~~~ <--- HERE
E
E                   return torch.abs(self), backward
E       RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.

../miniconda3/envs/nops/lib/python3.8/site-packages/torch/autograd/__init__.py:147: RuntimeError

which seems like a legit error (although we don’t understand where exactly this happens since all leaf parameters are real-valued scalars) but, more importantly, everything works fine if we either

  • skip the first test,
  • or skip one of the (JIT) compilation steps,
  • or use different parameters to initialize the model in the first or second test.

There is no shared state between the two test cases, and we don't know what could possibly cause this problem. Our model code is a bit convoluted, but the test code is rather short:

import torch
import torch.nn.functional as F

import fno


def test_a():
    x = ...
    model_args = ...

    model = fno.FNO(model_args)
    model = torch.jit.script(model)

    out = model(x)


def test_b():
    x = torch.rand(...)
    y = torch.rand(...)
    model_args = ...

    model = fno.FNO(model_args)
    model = torch.jit.script(model)

    model.train()
    out = model(x)
    mse = F.mse_loss(out, y, reduction="mean")
    mse.backward()

We are tempted to assume that this is a PyTorch bug, but we don't know how to proceed from here. Do you have any ideas on how to nail down the exact problem?

EDIT: we managed to narrow down the error:

def test_c():
    x = torch.rand(...)
    y = torch.rand(...)
    model_args = ...

    model = fno.FNO(model_args)
    model = torch.jit.script(model)  # (1)
    model(x)                         # (2)

    model.train()
    out = model(x)
    mse = F.mse_loss(out, y, reduction="mean")
    mse.backward()

test_c() now triggers the error on its own. Removing either line (1) or (2) suppresses the error…

What happens is that on the second run you get the JIT-autodiffed abs, and that was broken when complex support was added to the derivative of abs. Sadly, it seems that while the autograd formula you get on the first run was updated about 10 months ago, the JIT has not been getting as much love in the complex-number transition.
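
For what it's worth, you can see the sign/sgn difference in isolation, without any JIT involved:

import torch

z = torch.tensor([3 + 4j, -1 + 0j])

try:
    torch.sign(z)    # raises: torch.sign is not intended to support complex numbers
except RuntimeError as e:
    print(e)

print(torch.sgn(z))  # the complex "sign" z / |z| (0 where z == 0): 0.6+0.8j, -1+0j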

The fix is trivial and in line with the exception: change sign to sgn here: pytorch/symbolic_script.cpp at 03a79f43e33c1bef65fc8912c27160d01e0e15d5 · pytorch/pytorch · GitHub
However, I have not managed to find a reproducing case; maybe the nightly versions have a fix where the JIT avoids complex numbers altogether, I would not know.
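
Concretely, going by the traceback, the patched formula in symbolic_script.cpp would read something like this (a sketch, not the verbatim upstream source):

def abs(self):
    def backward(grad_output):
        return grad_output * self.sgn()   # was: grad_output * self.sign()

    return torch.abs(self), backward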

Best regards

Thomas


Thanks a lot for your help! I am afraid I am still a bit lost with your explanation about the “autodiffed abs”. Could you explain this in a bit more detail? In particular, why does this error only occur after the second compilation / after the second forward evaluation?

The short answer is that the JIT collects information during (by default only) the first run and only after that applies its optimizations. If you want plenty of detail, I can offer a blog post on what happens when you call a JIT function and one on optimization in particular.
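
As a small illustration of that two-phase behaviour (the toy function f is made up; graph_for is the debugging helper PyTorch's own JIT tests use to look at the graph that actually ran):

import torch

@torch.jit.script
def f(x):
    return x.abs().sum()

x = torch.rand(4)

print(f.graph)         # the graph as scripted, before any run
f(x)                   # first call: the profiling executor records shapes/dtypes
f(x)                   # subsequent calls use the specialized, optimized graph
print(f.graph_for(x))  # the graph that actually ran for inputs like x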

Best regards

Thomas


Ok. Thanks for providing the links and for your swift help :)

EDIT: using the nightly build I get the same error

Thank you for checking!
So if you can get the example down to a short model built from standard PyTorch pieces, or some such, either of us could do a PR to fix it.
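
For orientation, such a reproducer might look roughly like this: an FNO-style toy with a complex intermediate from torch.fft.rfft and the script + warm-up + backward pattern from your test_c. The Tiny module here is made up, and I have not checked whether this particular toy actually hits the autodiffed abs path.

import torch
import torch.nn.functional as F


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(16))

    def forward(self, x):
        # complex intermediate via rfft, then abs -- the op whose JIT derivative
        # multiplies by torch.sign
        z = torch.fft.rfft(self.weight * x)
        return torch.abs(z).sum(dim=-1)


x = torch.rand(8, 16)
y = torch.rand(8)

model = torch.jit.script(Tiny())  # (1) script the model
model(x)                          # (2) profiling run
out = model(x)                    # second run: optimized graph
F.mse_loss(out, y, reduction="mean").backward()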

We are currently quite busy, but I will see what we can do. Besides, wouldn't changing torch.sign to torch.sgn in symbolic_script.cpp be a good idea anyhow?