I’m not sure if it is fully related to the issue, but on my 2x4090 I get the following:
import torch
v = torch.randn(5, device='cuda:0')
print(v)
print(v.to('cuda:1'))
print(v.to('cpu').to('cuda:1'))
tensor([ 1.5336, 0.8161, -0.9325, -0.9513, 0.1360], device='cuda:0')
tensor([0., 0., 0., 0., 0.], device='cuda:1')
tensor([ 1.5336, 0.8161, -0.9325, -0.9513, 0.1360], device='cuda:1')
It looks like there is a bug (very likely on NVIDIA's side) in the GPU-to-GPU memory copy: the direct copy to cuda:1 comes back as all zeros, while the copy staged through the CPU is correct. I have the latest NVIDIA driver and tried the latest stable PyTorch as well as the PyTorch 2.0 preview.
And here is the output of NVIDIA's simpleP2P test:
[./simpleP2P] - Starting...
Checking for multiple GPUs...
CUDA-capable device count: 2
Checking GPU(s) for support of peer to peer memory access...
> Peer access from NVIDIA GeForce RTX 4090 (GPU0) -> NVIDIA GeForce RTX 4090 (GPU1) : Yes
> Peer access from NVIDIA GeForce RTX 4090 (GPU1) -> NVIDIA GeForce RTX 4090 (GPU0) : Yes
Enabling peer access between GPU0 and GPU1...
Allocating buffers (64MB on GPU0, GPU1 and CPU Host)...
Creating event handles...
cudaMemcpyPeer / cudaMemcpy between GPU0 and GPU1: 12.61GB/s
Preparing host buffer and memcpy to GPU0...
Run kernel on GPU1, taking source data from GPU0 and writing to GPU1...
Run kernel on GPU0, taking source data from GPU1 and writing to GPU0...
Copy data back to host from GPU0 and verify results...
Verification error @ element 1: val = 0.000000, ref = 4.000000
Verification error @ element 2: val = 0.000000, ref = 8.000000
Verification error @ element 3: val = 0.000000, ref = 12.000000
Verification error @ element 4: val = 0.000000, ref = 16.000000
Verification error @ element 5: val = 0.000000, ref = 20.000000
Verification error @ element 6: val = 0.000000, ref = 24.000000
Verification error @ element 7: val = 0.000000, ref = 28.000000
Verification error @ element 8: val = 0.000000, ref = 32.000000
Verification error @ element 9: val = 0.000000, ref = 36.000000
Verification error @ element 10: val = 0.000000, ref = 40.000000
Verification error @ element 11: val = 0.000000, ref = 44.000000
Verification error @ element 12: val = 0.000000, ref = 48.000000
Disabling peer access...
Shutting down...
Test failed!
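In case it helps anyone check their own machine, here is a rough sketch of the same comparison as a reusable check. It copies a tensor directly from one GPU to another and compares it with a copy staged through the CPU. The function name `check_gpu_to_gpu_copy` and its parameters are my own, not part of PyTorch; the guarded import and the `None` return are only there so the script exits cleanly on machines without PyTorch or without two GPUs.

```python
try:
    import torch
except ImportError:  # assumption: allow a clean exit where PyTorch is not installed
    torch = None


def check_gpu_to_gpu_copy(src=0, dst=1, n=1024):
    """Compare a direct GPU-to-GPU copy with a copy staged through the CPU.

    Hypothetical helper. Returns True if both copies match, False if they
    differ, and None if the check cannot run (no PyTorch, no CUDA, or too
    few GPUs for the requested device indices).
    """
    if torch is None or not torch.cuda.is_available():
        return None
    if torch.cuda.device_count() <= max(src, dst):
        return None

    # What the driver reports about peer-to-peer access between the two GPUs
    p2p = torch.cuda.can_device_access_peer(src, dst)
    print(f"peer access cuda:{src} -> cuda:{dst}: {p2p}")

    v = torch.randn(n, device=f"cuda:{src}")
    direct = v.to(f"cuda:{dst}")            # direct device-to-device copy
    staged = v.to("cpu").to(f"cuda:{dst}")  # same data, routed through host memory

    # torch.equal returns a Python bool, so it waits for both copies to finish
    return torch.equal(direct, staged)


if __name__ == "__main__":
    print("direct copy matches staged copy:", check_gpu_to_gpu_copy())
```

On a setup showing the behaviour above I would expect this to report a mismatch, since the direct copy arrives as zeros. Until the underlying problem is fixed, routing the copy through the CPU as in `v.to('cpu').to('cuda:1')` gave correct values in my test.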