Using PyTorch with CUDA

Hello, I am trying to run PyTorch with CUDA on two machines. In both cases I don't get the expected results. I will describe both cases to see if I am doing something wrong or if there is some incompatibility I cannot see.

Case 1)

  • CUDA installed (nvcc --version): 11.2
  • torch 1.12.1+cu113
  • torchaudio 0.12.1+cu113
  • torchvision 0.13.1+cu113

I run the following code (some parts of it are omitted for simplicity):

print(f"Is CUDA supported by this system? {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

# Storing ID of current CUDA device
cuda_id = torch.cuda.current_device()
print(f"ID of current CUDA device: {torch.cuda.current_device()}")
       
print(f"Name of current CUDA device: {torch.cuda.get_device_name(cuda_id)}")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = model.cuda()
model = model.eval() 

with torch.no_grad():
    for i, (x, y) in enumerate(val_loader):
        x = x.cuda()
        y = y.cuda()
        y_pred = model(x)

The result I get is the following:


Is CUDA supported by this system? True
CUDA version: 11.3
ID of current CUDA device: 0
Name of current CUDA device: NVIDIA A100-PCIE-40GB
inputs torch.Size([1024, 3, 224, 224])  \ rep  0  \ labels  torch.Size([1024])
Segmentation fault (core dumped)

When the model runs the inference, it fails. I don't know why, because when I run the example with gdb, it gets stuck and does not finish.

If I run commands like nvidia-smi -lms to check whether some program is executing on the GPU, my process appears, but with 0% GPU utilization.

If, for instance, I remove x.cuda() and y.cuda(), the program finishes fine, and nvidia-smi shows some MiB of GPU memory allocated but still 0% utilization. I suppose this happens because the model is on the GPU but the inference is not.

I also changed the code to use model.to(device) and x.to(device) instead of model.cuda() and x.cuda(), with the same output.
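For reference, a minimal self-contained version of what I am doing (a plain fp32 torchvision ResNet50 and a random batch stand in here for my real model and data loader) looks like this:

import torch
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in model and input instead of my real model and val_loader
model = torchvision.models.resnet50().to(device).eval()
x = torch.randn(8, 3, 224, 224, device=device)

print(f"Model parameters on: {next(model.parameters()).device}")
print(f"Input batch on: {x.device}")

with torch.no_grad():
    y_pred = model(x)
print(y_pred.shape)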

Case 2)

  • CUDA installed (nvcc --version): 11.8
  • torch 2.4.1+cu118
  • torchaudio 2.4.1+cu118
  • torchvision 0.19.1+cu118

Same code, and this is the output:

Is CUDA supported by this system? True
CUDA version: 11.8
ID of current CUDA device: 0
Name of current CUDA device: Tesla V100-PCIE-32GB
Segmentation fault (core dumped)

I get the same output with different CUDA configurations and different GPUs, and in my code I can detect not only the CUDA version but also the GPU with functions like current_device(), get_device_name() or is_available().

Again, if I don't use x.to(device) or y.to(device), it works but with 0% GPU usage. Surprisingly, model.to(device) or model.cuda() never fails.

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-PCIE...  Off  | 00000000:3D:00.0 Off |                    0 |
| N/A   35C    P0    35W / 250W |   1684MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

The only explanation I can find is that somehow I am using the wrong code to move the model and the tensors to the GPU, or that my torch version is not compatible.

Any help would be appreciated,

Thank you.

Have you tried to run the model on the CPU?
You are getting a segmentation fault (core dump). It's a bit generic, but it typically means an error on the C++ side.

Try to run your workload via:

gdb --args python script.py
...
run
...
bt

to print the backtrace of the segfault.

This is the output from gdb:

Thread 66 "python3.8" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7fff5b1aa700 (LWP 25254)]
__memcpy_ssse3 () at ../sysdeps/x86_64/multiarch/memcpy-ssse3.S:2509
2509    ../sysdeps/x86_64/multiarch/memcpy-ssse3.S: No such file or directory.
(gdb) bt
#0  __memcpy_ssse3 () at ../sysdeps/x86_64/multiarch/memcpy-ssse3.S:2509
#1  0x00007ffeefdb5754 in void fbgemm::pack_a_with_im2col_opt<2, 512>(fbgemm::conv_param_t<2> const&, fbgemm::block_type_t const&, unsigned char const*, unsigned char*, int, int*, int, int, bool) ()
   from /mnt/beegfs/gap/izcagal/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#2  0x00007ffeefdb60ce in fbgemm::PackAWithIm2Col<unsigned char, int, 2>::pack(fbgemm::block_type_t const&) () from /mnt/beegfs/gap/izcagal/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#3  0x00007ffeefd3d8a4 in void fbgemm::fbgemmPacked<fbgemm::PackAWithIm2Col<unsigned char, int, 2>, fbgemm::PackBMatrix<signed char, int>, unsigned char, fbgemm::ReQuantizeOutput<true, (fbgemm::QuantizationGranularity)2, float, unsigned char, int, fbgemm::DoNothing<unsigned char, unsigned char> > >(fbgemm::PackMatrix<fbgemm::PackAWithIm2Col<unsigned char, int, 2>, fbgemm::PackAWithIm2Col<unsigned char, int, 2>::inpType, fbgemm::PackAWithIm2Col<unsigned char, int, 2>::accType>&, fbgemm::PackMatrix<fbgemm::PackBMatrix<signed char, int>, fbgemm::PackBMatrix<signed char, int>::inpType, fbgemm::PackBMatrix<signed char, int>::accType>&, unsigned char*, int*, unsigned int, fbgemm::ReQuantizeOutput<true, (fbgemm::QuantizationGranularity)2, float, unsigned char, int, fbgemm::DoNothing<unsigned char, unsigned char> > const&, int, int, fbgemm::BlockingFactors const*) ()
   from /mnt/beegfs/gap/izcagal/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#4  0x00007ffeefd7e815 in int fbgemm::fbgemmConv<fbgemm::ReQuantizeOutput<true, (fbgemm::QuantizationGranularity)2, float, unsigned char, int, fbgemm::DoNothing<unsigned char, unsigned char> >, 2, int>(fbgemm::conv_param_t<2> const&, unsigned char const*, fbgemm::PackWeightsForConv<2, signed char, int>&, fbgemm::ReQuantizeOutput<true, (fbgemm::QuantizationGranularity)2, float, unsigned char, int, fbgemm::DoNothing<unsigned char, unsigned char> >::outType*, int*, fbgemm::ReQuantizeOutput<true, (fbgemm::QuantizationGranularity)2, float, unsigned char, int, fbgemm::DoNothing<unsigned char, unsigned char> >&, int, int, fbgemm::BlockingFactors const*) ()
   from /mnt/beegfs/gap/izcagal/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#5  0x00007ffeebd020cb in void at::internal::invoke_parallel<PackedConvWeight<2>::apply_impl<true>(at::Tensor const&, double, long)::{lambda(long, long)#2}>(long, long, long, PackedConvWeight<2>::apply_impl<true>(at::Tensor const&, double, long)::{lambda(long, long)#2} const&) [clone ._omp_fn.0] ()
   from /mnt/beegfs/gap/izcagal/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#6  0x00007fffdc5c0405 in ?? ()
   from /mnt/beegfs/gap/izcagal/.local/lib/python3.8/site-packages/torch/lib/libgomp-a34b3233.so.1
#7  0x00007ffff7bbb6db in start_thread (arg=0x7fff5b1aa700) at pthread_create.c:463
--Type <RET> for more, q to quit, c to continue without paging--c
#8  0x00007ffff713f61f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

@ptrblck, could you take a look at the backtrace?

I like using accelerate to optionally place all my data on the GPU.
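Roughly, a minimal sketch of that idea (with a stand-in linear model and a random dataset, not your actual setup) would be:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # picks up the GPU automatically if one is available

# Stand-in model and data
model = torch.nn.Linear(10, 2)
loader = DataLoader(TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,))), batch_size=8)

# prepare() moves the model and wraps the dataloader so batches land on accelerator.device
model, loader = accelerator.prepare(model, loader)

model.eval()
with torch.no_grad():
    for x, y in loader:
        y_pred = model(x)  # x is already on the right device

That way the explicit .cuda()/.to(device) calls disappear from the training/inference loop.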

Thanks for providing the stacktrace! It seems fbgemm is failing in some SSE3 instructions, but I don't know which CPU features are required for its execution. CC @malfet: do you know if SSE3 is supported in current fbgemm builds?

@ptrblck I think SSE3 is a red herring here (note that it fails with SIGSEGV, not SIGILL), i.e. it looks like either the src or the dst pointer is invalid.
Also, fbgemm just calls memcpy (for example, see FBGEMM/src/PackAWithIm2Col.cc at 49fa9a55ba97a0d655cf2c51bf595caab37638e1 · pytorch/FBGEMM · GitHub), and then glibc decides to dispatch to one or another accelerated flavor based on the CPU instruction set support.

@Izan_C_G do you mind running the info registers and disassemble commands when the exception happens?

@ptrblck @malfet Problem solved (more or less).
The source of the problem was not the CUDA installation/version or its compatibility with PyTorch. It was a problem with the NN model used for the inference.

I was working with the quantized ResNet50 model. If ResNet50 is used with the fp32 datatype, it works fine and the GPU is used as it should be. However, the quantized models always produce segmentation faults.

I realized that in the torchvision 0.19 API documentation there is a note saying that quantized models don't work on the GPU. I think this is the problem; I had overlooked that note.
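To illustrate (a minimal sketch, assuming the pretrained quantized ResNet50 weights from torchvision): the fp32 model can be moved to the GPU, while the quantized variant runs through the int8 CPU backend (fbgemm on x86) and has to stay on the CPU together with its inputs:

import torch
import torchvision

x = torch.randn(1, 3, 224, 224)

# fp32 ResNet50: works on the GPU
fp32_model = torchvision.models.resnet50().eval()
if torch.cuda.is_available():
    fp32_model = fp32_model.cuda()
    with torch.no_grad():
        out_gpu = fp32_model(x.cuda())

# Quantized ResNet50: int8 weights executed by the CPU backend, so model and input stay on the CPU
q_model = torchvision.models.quantization.resnet50(
    weights=torchvision.models.quantization.ResNet50_QuantizedWeights.DEFAULT,
    quantize=True,
).eval()
with torch.no_grad():
    out_cpu = q_model(x)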

However, the problem I have now is that I want to use quantized models, and I would like to run them on the GPU. Is there any possible solution for that?