CUDA error: driver shutting down when running CPU-only code

I have a custom-built PyTorch 2.0.1 with CUDA 12.1 (built using the pytorch/conda-builder:cuda121 image), and I observed an unexpected "CUDA error: driver shutting down" error when running CPU-only code with autograd.

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: driver shutting down
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/conda/conda-bld/pytorch_1691715301395/work/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f3ac24334d7 in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f3ac23fd36b in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f3ac24d8f78 in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x19048 (0x7f3ac24ae048 in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x44616e2 (0x7f3afd3096e2 in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x2a (0x7f3afd30a40a in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #6: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x5c (0x7f3b0524c5dc in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0xd3e79 (0x7f3b0aed2e79 in /fsx/conda/envs/junpu_gpt3_13/lib/python3.10/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #8: <unknown function> + 0x8609 (0x7f3b4c614609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #9: clone + 0x43 (0x7f3b4c3d3133 in /lib/x86_64-linux-gnu/libc.so.6)
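To see which component raises the error, the trace can be reduced to frame index, symbol and library. The following is a minimal standard-library sketch. It assumes the `frame #N: symbol + offset (address in library)` layout shown above, and `summarize_frames` is a hypothetical helper, not a PyTorch API. Applied to this trace, it shows that the CUDA check (frames #2 and #3, libc10_cuda.so) is reached from torch::autograd::Engine::thread_init (frames #5 and #6). The frames below that are libstdc++, libpthread and clone, which suggests the error happens on an autograd worker thread rather than in the user code.

```python
from __future__ import annotations

import re

# One line of a c10::Error backtrace, e.g.
# frame #5: torch::autograd::Engine::thread_init(...) + 0x2a (0x7f3afd30a40a in /path/libtorch_cpu.so)
FRAME_RE = re.compile(
    r"^frame #(?P<index>\d+): (?P<symbol>.+?) \+ 0x[0-9a-f]+ "
    r"\(0x[0-9a-f]+ in (?P<library>\S+)\)$"
)


def summarize_frames(trace: str) -> list[tuple[int, str, str]]:
    """Return (index, symbol, library file name) for every frame line in `trace`."""
    frames = []
    for line in trace.splitlines():
        match = FRAME_RE.match(line.strip())
        if match is None:
            continue  # error message, blank lines and other non-frame text
        library = match.group("library").rsplit("/", 1)[-1]  # drop the directory part
        frames.append((int(match.group("index")), match.group("symbol"), library))
    return frames
```

Run on the full trace above, it returns ten entries. They range from (0, "c10::Error::Error(c10::SourceLocation, std::string)", "libc10.so") down to (9, "clone", "libc.so.6").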

Below is the minimal reproduction code I used. When I call only f2(), I see the error. When I call only f1(), or call both f1() and f2() in either order, the error does not appear. It seems that specifying the device has some significance here.

import torch

def f1():
    print("f1: device specified, use cuda if exist")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.tensor(3.0, requires_grad=True, device=device)
    y = x**2 + 2*x + 1
    y.backward()
    print("Gradient of x:", x.grad)
    learning_rate = 0.1
    with torch.no_grad():
        x -= learning_rate * x.grad
    x.grad.zero_()
    print("Updated x:", x)

def f2():
    print("f2: device not specified")
    x = torch.tensor(3.0, requires_grad=True)
    y = x**2 + 2*x + 1
    y.backward()
    print("Gradient of x:", x.grad)
    learning_rate = 0.1
    with torch.no_grad():
        x -= learning_rate * x.grad
    x.grad.zero_()
    print("Updated x:", x)

# Calling only f2() triggers the error; calling only f1(), or both in either order, does not.
f1()
f2()
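Because the outcome depends on which functions are called and in what order, each order is best tested in a fresh interpreter. The "driver shutting down" message also suggests that the error appears while the process is exiting. The following is a minimal standard-library sketch. It assumes that f1() and f2() are saved, without the calls at the bottom, in a hypothetical module `repro.py`. The names `run_case`, `CASES` and `SIGNATURE` are illustrative.

```python
from __future__ import annotations

import os
import subprocess
import sys

# Call orders from the report; each one runs in its own interpreter.
CASES = {
    "f2 only": ["f2"],
    "f1 only": ["f1"],
    "f1 then f2": ["f1", "f2"],
    "f2 then f1": ["f2", "f1"],
}
SIGNATURE = "driver shutting down"


def run_case(module: str, calls: list[str], module_dir: str,
             extra_env: dict[str, str] | None = None) -> tuple[int, bool]:
    """Import `module` and call the named functions in order in a new Python process.

    Returns (exit code, True if stderr contains the CUDA shutdown message).
    """
    script = f"import {module}\n" + "".join(f"{module}.{name}()\n" for name in calls)
    env = {**os.environ, **(extra_env or {})}
    # Make `module_dir` importable in the child without touching this process.
    env["PYTHONPATH"] = os.pathsep.join(p for p in (module_dir, env.get("PYTHONPATH")) if p)
    proc = subprocess.run([sys.executable, "-c", script], env=env,
                          capture_output=True, text=True)
    return proc.returncode, SIGNATURE in proc.stderr


def main(module: str = "repro", module_dir: str = ".") -> None:
    # "repro" is the hypothetical module holding f1() and f2().
    for label, calls in CASES.items():
        code, crashed = run_case(module, calls, module_dir)
        print(f"{label:<11} exit={code} driver_shutting_down={crashed}")


if __name__ == "__main__":
    main()
```

The child processes use `sys.executable`, so running the script from a different environment (for example a nightly build) compares builds without code changes. Passing `extra_env={"CUDA_LAUNCH_BLOCKING": "1"}` applies the hint from the error message.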

I cannot reproduce the issue using a nightly build with CUDA 12.1, regardless of the order in which I execute the functions.

Thanks for taking a look. I will try a new build from the current main branch instead of the 2.0.1 release tag to check further.