Consider the following code, which I have extracted as a minimal example from what is actually a more complicated scenario:
import torch
@torch.jit.script
def func_jit(x):
    return x.mul(x.tanh())
print(torch.__version__)
X = torch.tensor(1.23, requires_grad=True, dtype=torch.double)
print(torch.autograd.gradcheck(func_jit, X, raise_exception=True, check_undefined_grad=True, check_batched_grad=True, check_backward_ad=True))
X = torch.tensor(1.23, requires_grad=True, dtype=torch.double)
print(torch.autograd.gradcheck(func_jit, X, raise_exception=True, check_undefined_grad=True, check_batched_grad=True, check_backward_ad=True))
This gives the following output:
1.12.1
True
Traceback (most recent call last):
  File "PATH/lib/python3.9/site-packages/torch/autograd/gradcheck.py", line 839, in _test_batched_grad
    result = vmap(vjp)(torch.stack(grad_outputs))
  File "PATH/lib/python3.9/site-packages/torch/_vmap_internals.py", line 271, in wrapped
    batched_outputs = func(*batched_inputs)
  File "PATH/lib/python3.9/site-packages/torch/autograd/gradcheck.py", line 822, in vjp
    results = grad(v)
  File "PATH/lib/python3.9/site-packages/torch/autograd/__init__.py", line 276, in grad
    return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "PATH/test_act_funcs.py", line 86, in <module>
    print(torch.autograd.gradcheck(func_jit, X, raise_exception=True, check_undefined_grad=True, check_batched_grad=True, check_backward_ad=True))
  File "PATH/lib/python3.9/site-packages/torch/autograd/gradcheck.py", line 1414, in gradcheck
    return _gradcheck_helper(**args)
  File "PATH/lib/python3.9/site-packages/torch/autograd/gradcheck.py", line 1442, in _gradcheck_helper
    _test_batched_grad(tupled_inputs, o, i)
  File "PATH/lib/python3.9/site-packages/torch/autograd/gradcheck.py", line 845, in _test_batched_grad
    raise GradcheckError(
torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
gradcheck or gradgradcheck failed while testing batched gradient computation.
This could have been invoked in a number of ways (via a test that calls
gradcheck/gradgradcheck directly or via an autogenerated test).
If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
with `check_batched_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.
If you're modifying an existing operator that supports batched grad computation,
or wish to make a new operator work with batched grad computation, please read
the following.
To compute batched grads (e.g., jacobians, hessians), we vmap over the backward
computation. The most common failure case is if there is a 'vmap-incompatible
operation' in the backward pass. Please see
NOTE: [How to write vmap-compatible backward formulas]
in the codebase for an explanation of how to fix this.
This error goes away if the body of func_jit becomes either x.mul(x) or x.tanh(), but it is also present for x.mul(x.mul(x)). It seems that having more than a single operation triggers the error? What is going on here? (A sketch of the variants I tried is below.)
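For reference, here is a minimal sketch of those variants. The names single_op and two_ops are just labels I made up for this post, and the comments describe the behaviour I observe with torch 1.12.1:

import torch

@torch.jit.script
def single_op(x):
    return x.tanh()          # single op: both gradcheck calls pass

@torch.jit.script
def two_ops(x):
    return x.mul(x.mul(x))   # chained ops: same failure as func_jit

X = torch.tensor(1.23, requires_grad=True, dtype=torch.double)
print(torch.autograd.gradcheck(single_op, X, check_batched_grad=True))  # True
print(torch.autograd.gradcheck(single_op, X, check_batched_grad=True))  # True
print(torch.autograd.gradcheck(two_ops, X, check_batched_grad=True))    # True (first call)
print(torch.autograd.gradcheck(two_ops, X, check_batched_grad=True))    # raises GradcheckError, as far as I can tell only on this second call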
Setting check_batched_grad=False ‘resolves’ the issue. Is batched grad fundamentally incompatible with torch.jit.script? Note, however, that even with check_batched_grad=True the first call to torch.autograd.gradcheck above actually succeeds without problems.