We are using torch to accelerate the computation of some iterative solvers within a bigger project.
The code runs perfectly on the CPU, but on the GPU we hit the runtime error below the second time we run the code.
The torch code is enclosed in a method that is called several times during a single run, and that doesn’t cause any problems. The issue appears only when we use a script to run the whole code several times: on the second run, the error is raised the first time the code tries to allocate a tensor:
kspace_torch = torch.tensor(kspace, dtype=torch.complex64, device=device)
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stack trace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
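One way to follow the CUDA_LAUNCH_BLOCKING hint is to put the variable into the environment of the process that runs the solvers, so it is already set before PyTorch or the C++ wrappers initialize CUDA. The sketch below only builds such an environment; the helper name cuda_debug_env is hypothetical. The TORCH_USE_CUDA_DSA hint refers to building PyTorch from source, so it cannot be switched on this way.

```python
import os


def cuda_debug_env(base=None):
    """Return a copy of an environment with synchronous CUDA launches enabled.

    CUDA_LAUNCH_BLOCKING must be set before CUDA is initialized in a process,
    so passing it to a freshly started child process is the reliable option.
    The input mapping (os.environ by default) is never modified.
    """
    env = dict(os.environ if base is None else base)
    # Kernel launches become synchronous, so the Python stack trace points at
    # the call that actually failed instead of a later, unrelated API call.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env
```

It could then be used as subprocess.run([sys.executable, "run_batch.py"], env=cuda_debug_env(), check=True), where run_batch.py stands in for the (hypothetical) script that triggers the failure. Expect the run to be noticeably slower while debugging.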
If we run the code twice in separate calls, there is no problem, so we assume that resources are not being freed properly, either by torch or by the other parts of the code that use the GPU. Does anyone have any idea what could be happening?
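To test the hypothesis that PyTorch is holding on to GPU resources between runs, a best-effort cleanup step could be called at the end of each run. This is a sketch under stated assumptions: the function name is hypothetical, the caller must drop its own tensor references first (for example del kspace_torch), and it only returns memory cached by PyTorch’s allocator. It cannot undo a device reset or other state changes made by the C++ wrappers.

```python
import gc


def release_torch_gpu_memory():
    """Best-effort cleanup of PyTorch's CUDA state between solver runs.

    Returns True if CUDA cleanup was performed, and False if torch is not
    installed or CUDA was never initialized in this process.
    """
    try:
        import torch
    except ImportError:
        return False
    # Free tensors that are only kept alive by unreachable reference cycles.
    gc.collect()
    if not torch.cuda.is_initialized():
        # No CUDA context exists yet, so there is nothing to release.
        return False
    # Wait for queued kernels; pending asynchronous errors are raised here.
    torch.cuda.synchronize()
    # Hand cached but unused blocks from PyTorch's allocator back to the driver.
    torch.cuda.empty_cache()
    return True
```

If the error persists even with this cleanup in place, that points away from PyTorch’s own allocations and toward state changed by the other GPU code.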
Could you update PyTorch to the latest nightly release and check whether you still see the error?
We were seeing similar issues in 2.0.0 (which should have been fixed in 2.0.1, if I’m not mistaken) on sm_89 devices using the BetterTransformer backend.
If this still doesn’t help, could you post a minimal and executable code snippet reproducing the issue, please?
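To check whether the GPU is one of the sm_89 devices mentioned above, and which PyTorch and CUDA runtime versions are actually loaded, a small report like the one below can be posted with the question. It is a sketch; both function names are hypothetical. PyTorch also ships python -m torch.utils.collect_env, which prints a fuller environment report.

```python
def sm_arch(capability):
    """Turn a (major, minor) compute-capability tuple into an 'sm_XY' name."""
    major, minor = capability
    return f"sm_{major}{minor}"


def torch_cuda_report():
    """Collect the version details that are useful in a bug report."""
    import torch  # imported lazily so sm_arch() works without torch installed

    report = {
        "torch": torch.__version__,
        "cuda_runtime": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        # One (name, architecture) pair per visible GPU.
        report["devices"] = [
            (torch.cuda.get_device_name(i),
             sm_arch(torch.cuda.get_device_capability(i)))
            for i in range(torch.cuda.device_count())
        ]
    return report
```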
Thanks for the answer! We’re actually using 2.0.1, but I also tried the nightly release and it didn’t help.
I wish I could reproduce it in a minimal code snippet, but this code also calls wrappers around internal C++ code that uses the GPU. The problem must come from the interaction with these wrappers, since we can’t reproduce it without those calls.
Could you explain what kind of “performance in the CUDA side” issues you are seeing that would warrant a device reset?
I don’t think you can recover from it on the PyTorch side; you would also need to restart the Python process.
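Since separate calls work and a restart is needed anyway, the batch script could start a new Python interpreter for every run instead of looping inside one process. When a child process exits, the OS and the CUDA driver reclaim everything it held on the GPU, including whatever the C++ wrappers changed. A minimal sketch, assuming each run can be launched as a script that takes a case name as its last argument; run_cases_isolated is a hypothetical helper.

```python
import subprocess
import sys


def run_cases_isolated(cases, script_args, timeout=None):
    """Run one solver invocation per case, each in a brand-new interpreter.

    script_args is everything between the interpreter and the case name,
    e.g. ["run_solver.py"]. Returns {case: exit_code} for the cases that ran
    and stops at the first failing case.
    """
    results = {}
    for case in cases:
        completed = subprocess.run(
            [sys.executable, *script_args, case],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        results[case] = completed.returncode
        if completed.returncode != 0:
            # Surface the child's traceback and stop at the first failure.
            sys.stderr.write(completed.stderr)
            break
    return results
```

For example, run_cases_isolated(["case_a", "case_b"], ["run_solver.py"]) would run the hypothetical run_solver.py twice, each time with a fresh CUDA context.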