Segmentation fault with multiple devices

I was training my model on two RTX 4090 GPUs by setting the following:

import os
# Must be set before CUDA is initialized (safest: before importing torch)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

I am running the nightly version of PyTorch, and it was working fine until a couple of days ago.

Now it fails with "Segmentation fault (core dumped)" or, occasionally, with a bus error.

Two things happened in between:

  • I updated PyTorch.

  • The driver was updated to version 535.98.

The CUDA version is 12.1.
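When reporting a regression like this, it helps to capture the exact PyTorch build, the CUDA version it was compiled against, and the driver version in one go. PyTorch ships `python -m torch.utils.collect_env` for this; the snippet below is a smaller sketch of the same idea, assuming a CUDA-enabled PyTorch install. It shells out to `nvidia-smi` for the driver version and returns `None` if that tool is not on the PATH. The function names are illustrative, not part of any library.

```python
import shutil
import subprocess
from typing import Optional

import torch


def nvidia_driver_version() -> Optional[str]:
    """Driver version as reported by nvidia-smi, or None if unavailable."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return None
    result = subprocess.run(
        [exe, "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    lines = result.stdout.strip().splitlines()
    # One line per GPU; all GPUs share the same driver, so the first is enough.
    return lines[0].strip() if result.returncode == 0 and lines else None


def environment_report() -> dict:
    """Hypothetical helper collecting the versions relevant to this bug."""
    count = torch.cuda.device_count()  # respects CUDA_VISIBLE_DEVICES
    return {
        "torch": torch.__version__,
        # CUDA version PyTorch was built with; None on CPU-only builds.
        "cuda_runtime": torch.version.cuda,
        # None if cuDNN is not available.
        "cudnn": torch.backends.cudnn.version(),
        "driver": nvidia_driver_version(),
        "visible_devices": [torch.cuda.get_device_name(i) for i in range(count)],
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running this before and after the driver downgrade makes it easy to confirm which combination is actually in use.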

Is anyone else experiencing similar issues with either the most recent PyTorch nightly builds or the new driver?

Using either GPU on its own works fine, but I can no longer use both at the same time as I used to.

I have downgraded the driver to the last working version (535.86), but I still cannot use both GPUs.

The problem only occurs when gradient checkpointing is enabled.
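A small, self-contained script can help isolate whether gradient checkpointing on two GPUs triggers the crash independently of the real model. The sketch below assumes the model is wrapped in `nn.DataParallel` (the post does not say which multi-GPU wrapper is used); `CheckpointedMLP` and `run_step` are hypothetical names. It uses `torch.utils.checkpoint.checkpoint`, and toggling `use_reentrant` switches between PyTorch's two checkpoint implementations, which is a cheap way to narrow the fault down. Without two visible GPUs it falls back to a single device or the CPU.

```python
import os

# Must be set before CUDA is initialized; setdefault keeps any values already exported.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(nn.Module):
    """Hypothetical toy model whose blocks are recomputed during backward."""

    def __init__(self, width: int = 64, depth: int = 4, use_reentrant: bool = False):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(depth)
        )
        self.head = nn.Linear(width, 1)
        self.use_reentrant = use_reentrant

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Activations inside `block` are not stored; they are recomputed in backward.
            x = checkpoint(block, x, use_reentrant=self.use_reentrant)
        return self.head(x)


def run_step(use_reentrant: bool = False, batch: int = 32, width: int = 64) -> float:
    """Run one forward/backward pass and return the loss value."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CheckpointedMLP(width=width, use_reentrant=use_reentrant).to(device)
    if torch.cuda.device_count() > 1:
        # Splits each batch across all visible GPUs (here cuda:0 and cuda:1).
        model = nn.DataParallel(model)
    # The reentrant implementation only propagates gradients into checkpointed
    # blocks if an input requires grad, so request it explicitly.
    x = torch.randn(batch, width, device=device, requires_grad=True)
    loss = model(x).pow(2).mean()
    loss.backward()
    return loss.item()


if __name__ == "__main__":
    for reentrant in (False, True):
        print(f"use_reentrant={reentrant}: loss={run_step(reentrant):.4f}")
```

If the crash reproduces with only one of the two `use_reentrant` settings, or only when `nn.DataParallel` is active, that detail narrows the search and is worth including in a bug report against the nightly build.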