We have a strange problem with a (JIT) compiled model. We have two `pytest` unit tests of our model. In both tests the model is initialized with the same parameters, compiled via `torch.jit.script(model)`, and fed with toy data, but only the second test uses the output to calculate a loss and eventually calls `backward` on it. If both tests are run, the second test throws:

```
>       Variable._execution_engine.run_backward(
            tensors, grad_tensors_, retain_graph, create_graph, inputs,
            allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
E       RuntimeError: The following operation failed in the TorchScript interpreter.
E       Traceback of TorchScript (most recent call last):
E         File "<string>", line 162, in <backward op>
E           def abs(self):
E               def backward(grad_output):
E                   return grad_output * self.sign()
E                                        ~~~~~~~~~ <--- HERE
E
E               return torch.abs(self), backward
E       RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.

../miniconda3/envs/nops/lib/python3.8/site-packages/torch/autograd/__init__.py:147: RuntimeError
```

which seems like a legitimate error (although we don't understand where exactly this happens, since all leaf parameters are real-valued scalars). More importantly, everything works fine if we either

- skip the first test,
- or skip one of the (JIT) compilation steps,
- or use different parameters to initialize the model in the first or second test.
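For context on the error message itself: for a complex `z`, `torch.sgn` computes `z / |z|` (the complex generalization of the sign function), which is what the derivative of `abs` needs, whereas `torch.sign` is only defined for real inputs. A torch-free sketch of the same identity (helper name is ours):

```python
def sgn(z: complex) -> complex:
    """Complex sign, z / |z| -- what torch.sgn computes; 0 maps to 0."""
    return z / abs(z) if z != 0 else 0


# The gradient of abs at z points along sgn(z), hence the backward
# formula `grad_output * self.sgn()` required for complex tensors.
z = 3 + 4j
assert abs(z) == 5.0
assert sgn(z) == complex(0.6, 0.8)
assert sgn(-2.0) == -1.0  # reduces to the ordinary sign for real inputs
```

So the error would make sense if some intermediate tensor in the backward pass were complex, even though the leaf parameters are real.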

There is no shared state between the two test cases and we don't know what could possibly cause this problem. Our model code is a bit convoluted but the test code is rather short:

```
def test_a():
    x = ...
    model_args = ...
    model = fno.FNO(model_args)
    model = torch.jit.script(model)
    out = model(x)


def test_b():
    x = torch.rand(...)
    y = torch.rand(...)
    model_args = ...
    model = Model(model_args)
    model = torch.jit.script(model)
    model.train()
    out = model(x)
    mse = F.mse_loss(out, y, reduction="mean")
    mse.backward()
```

We are tempted to assume that this is a `pytorch` bug, but we don't know how to proceed from here. Do you have any ideas how to nail down the exact problem?
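One way to find out whether interpreter-wide state leaks between the two tests is to run each test in its own process, for example with the `pytest-forked` plugin; if the failure disappears under isolation, some process-global cache is involved. A minimal hand-rolled sketch of the same idea (the helper name is ours):

```python
import multiprocessing as mp


def run_isolated(fn):
    """Run a test function in a forked child process (POSIX only).

    Any process-wide side effects of fn -- such as a populated TorchScript
    graph cache -- stay in the child and die with it, so a later test
    cannot be affected by an earlier one.
    """
    ctx = mp.get_context("fork")
    p = ctx.Process(target=fn)
    p.start()
    p.join()
    assert p.exitcode == 0, f"{fn.__name__} failed in an isolated process"
```

Running `test_a` and `test_b` through `run_isolated` should distinguish a genuine problem in `backward` itself from state carried over from the first test.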

**EDIT:** we managed to narrow down the error:

```
def test_c():
    x = torch.rand(...)
    y = torch.rand(...)
    model_args = ...
    model = fno.FNO(model_args)
    model = torch.jit.script(model)  # (1)
    model(x)  # (2)
    model.train()
    out = model(x)
    mse = F.mse_loss(out, y, reduction="mean")
    mse.backward()
```

`test_c()` now triggers the error on its own. Removing either line `(1)` or `(2)` suppresses the error…
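That the warm-up forward `(2)` is required suggests a hypothesis (ours, unverified): the first call lets TorchScript's profiling executor record and cache a specialized graph, and its cached backward is then reused for the later differentiable call. One diagnostic, using private `torch._C` toggles that we believe exist in the PyTorch 1.8/1.9 series this traceback comes from (treat that as an assumption), is to switch the profiling executor off and re-run `test_c`:

```python
import torch

# Diagnostic only, private API (assumption: PyTorch 1.8/1.9 series).
# With the profiling executor disabled, scripted modules fall back to the
# legacy executor and do not cache profile-specialized differentiable
# graphs across calls.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
```

If `test_c` passes with these set, the cross-call graph cache is implicated and a minimal reproduction would be worth filing upstream.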