RuntimeError: invalid dtype for bias - should match query's dtype #144018

I am training the X-CLIP model on a multi-GPU setup (3 GPUs). When I start the training process, I encounter the following error:

" RuntimeError: invalid dtype for bias - should match query’s dtype "

Here is the complete traceback of the error:

UserWarning: torch.utils.checkpoint.checkpoint_sequential: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
warnings.warn(
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/main.py", line 283, in <module>
[rank0]:     main(config)
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/main.py", line 104, in main
[rank0]:     train_one_epoch(epoch, model, criterion, optimizer, lr_scheduler, train_loader, text_labels, config, mixup_fn)
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/main.py", line 149, in train_one_epoch
[rank0]:     output = model(images, texts)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs) # type: ignore[index]
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/models/xclip.py", line 135, in forward
[rank0]:     text_features = self.cache_text(text)
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/models/xclip.py", line 125, in cache_text
[rank0]:     self.cache_text_features = self.encode_text(text)
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/models/xclip.py", line 97, in encode_text
[rank0]:     x = self.transformer(x)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/clip/model.py", line 87, in forward
[rank0]:     return self.resblocks(x)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
[rank0]:     input = module(input)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/clip/model.py", line 75, in forward
[rank0]:     x = x + self.attention(self.ln_1(x))
[rank0]:   File "/data/Hayat_Research_Data/VideoX/X-CLIP/clip/model.py", line 72, in attention
[rank0]:     return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1368, in forward
[rank0]:     attn_output, attn_output_weights = F.multi_head_attention_forward(
[rank0]:   File "/home/hayatullah/anaconda3/envs/VFL/lib/python3.10/site-packages/torch/nn/functional.py", line 6278, in multi_head_attention_forward
[rank0]:     attn_output = scaled_dot_product_attention(
[rank0]: RuntimeError: invalid dtype for bias - should match query's dtype
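
For reference, I believe the failing call reduces to something like the following (a hypothetical sketch, not code from X-CLIP; shapes are arbitrary): half-precision query/key/value tensors combined with a float32 attn_mask passed to scaled_dot_product_attention.

```python
import torch
import torch.nn.functional as F

# Hypothetical repro sketch: fp16 query/key/value with an fp32 attn_mask,
# mirroring CLIP's fp16 weights meeting a float32 attention mask.
q = torch.randn(2, 8, 77, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 77, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 77, 64, device="cuda", dtype=torch.float16)
mask = torch.zeros(77, 77, device="cuda", dtype=torch.float32)  # dtype mismatch

# With a floating-point attn_mask the call may be routed to a fused backend
# that requires the mask ("bias") dtype to match the query's dtype.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```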

Versions

PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-45-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX 6000 Ada Generation
GPU 1: NVIDIA RTX 6000 Ada Generation
GPU 2: NVIDIA RTX 6000 Ada Generation

Nvidia driver version: 550.107.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper PRO 5975WX 32-Cores
CPU family: 25
Model: 8
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU max MHz: 7006.6401
CPU min MHz: 1800.0000
BogoMIPS: 7186.58
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
Virtualization: AMD-V
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 16 MiB (32 instances)
L3 cache: 128 MiB (4 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] numpydoc==1.7.0
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] torch==2.5.1
[pip3] torchaudio==2.4.0
[pip3] torchvision==0.20.1
[pip3] triton==3.0.0
[conda] blas 1.0 mkl
[conda] cuda-cudart 12.4.127 0 nvidia
[conda] cuda-cupti 12.4.127 0 nvidia
[conda] cuda-libraries 12.4.1 0 nvidia
[conda] cuda-nvrtc 12.4.127 0 nvidia
[conda] cuda-nvtx 12.4.127 0 nvidia
[conda] cuda-opencl 12.6.77 0 nvidia
[conda] cuda-runtime 12.4.1 0 nvidia
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libcublas 12.4.5.8 0 nvidia
[conda] libcufft 11.2.1.3 0 nvidia
[conda] libcurand 10.3.7.77 0 nvidia
[conda] libcusolver 11.6.1.9 0 nvidia
[conda] libcusparse 12.3.1.170 0 nvidia
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] libnvjitlink 12.4.127 0 nvidia
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.11 py310h5eee18b_0
[conda] mkl_random 1.2.8 py310h1128e8f_0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] numpydoc 1.7.0 py310h06a4308_0
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] pytorch 2.5.1 py3.10_cuda12.4_cudnn9.1.0_0 pytorch
[conda] pytorch-cuda 12.4 hc786d27_7 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 2.4.0 pypi_0 pypi
[conda] torchtriton 3.1.0 py310 pytorch
[conda] torchvision 0.20.1 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

I would appreciate it if anyone who has encountered and resolved this error could share their solution.

The call fails by raising the invalid-dtype error, so check all of the inputs (query, key, value, and attn_mask) and make sure their dtypes are a valid combination; see the sketch below.
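
For example, a hypothetical debugging wrapper (not part of X-CLIP or PyTorch; the name checked_sdpa is my own) that prints the dtypes involved before calling scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

def checked_sdpa(q, k, v, attn_mask=None):
    # Print the dtypes involved; a floating-point attn_mask must match the query's dtype.
    print("query:", q.dtype, "key:", k.dtype, "value:", v.dtype,
          "attn_mask:", None if attn_mask is None else attn_mask.dtype)
    if attn_mask is not None and attn_mask.dtype not in (torch.bool, q.dtype):
        raise TypeError(
            f"attn_mask dtype {attn_mask.dtype} does not match query dtype {q.dtype}"
        )
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```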

I have resolved the issue :slightly_smiling_face: by casting explicitly with torch.autocast("cuda"). This appears to be a potential bug in PyTorch (I believe): when running PyTorch 2 on CPU instead of CUDA, the error occurs because PyTorch converts the query to torch.bfloat16, which no longer matches the dtype of the attention_mask (torch.float32). PyTorch 2 routes attention through scaled_dot_product_attention, which cannot handle mismatched dtypes between the query and the mask.
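
Concretely, my workaround was roughly the following (a sketch, not the exact X-CLIP training loop; model, images, and texts refer to the objects in the traceback above):

```python
import torch

# Workaround: run the forward pass under autocast so the query/key/value and
# the attention mask end up with matching dtypes inside scaled_dot_product_attention.
with torch.autocast("cuda"):
    output = model(images, texts)
```

An alternative that should also work (untested on my side) is to cast the mask to the input's dtype inside the attention method in clip/model.py (line 72 in the traceback), e.g. self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask.to(dtype=x.dtype, device=x.device))[0].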