F.scaled_dot_product_attention causing RuntimeError: CUDA error: invalid argument

I am trying to use torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) in my attention module, but I keep getting an error from this API and I am not sure what is causing it. I would appreciate any help.
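
For reference, this is roughly how I call it (a minimal, made-up sketch, not my actual module):

import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, num_heads, seq_len, head_dim)
query = torch.randn(16, 4, 10, 64, device="cuda", requires_grad=True)
key = torch.randn(16, 4, 10, 64, device="cuda")
value = torch.randn(16, 4, 10, 64, device="cuda")

out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
out.sum().backward()  # the failure shows up during the backward pass

and this is the traceback I get: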


File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 204, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Which PyTorch version are you using? If I’m not mistaken, this was a known issue in 2.0.x and should already be fixed in the current nightly builds.
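
If updating isn't an option right away, you could also check which fused SDPA backends are enabled and force the math fallback to narrow down whether one of the fused kernels is raising the invalid argument, e.g. (rough sketch with random inputs):

import torch
import torch.nn.functional as F

print(torch.__version__)
print(torch.backends.cuda.flash_sdp_enabled(),
      torch.backends.cuda.mem_efficient_sdp_enabled(),
      torch.backends.cuda.math_sdp_enabled())

q = k = v = torch.randn(16, 4, 10, 64, device="cuda", requires_grad=True)

# Disable the fused kernels and fall back to the pure math implementation
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    out = F.scaled_dot_product_attention(q, k, v)
    out.sum().backward()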

This is my working environment:

PyTorch version: 2.1.0a0+fe05266
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 13 2023, 10:26:41)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-53-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 530.30.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   43 bits physical, 48 bits virtual
CPU(s):                          24
On-line CPU(s) list:             0-23
Thread(s) per core:              2
Core(s) per socket:              12
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           113
Model name:                      AMD Ryzen 9 3900 12-Core Processor
Stepping:                        0
Frequency boost:                 enabled
CPU MHz:                         2200.000
CPU max MHz:                     4359.3750
CPU min MHz:                     2200.0000
BogoMIPS:                        6188.49
Virtualization:                  AMD-V
L1d cache:                       384 KiB
L1i cache:                       384 KiB
L2 cache:                        6 MiB
L3 cache:                        64 MiB
NUMA node0 CPU(s):               0-23
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.1.0a0+fe05266
[pip3] torch-tensorrt==1.4.0.dev0
[pip3] torchtext==0.13.0a0+fae8e8c
[pip3] torchvision==0.15.0a0
[pip3] triton==2.0.0

It seems you didn’t check the latest nightly, so could you please post a minimal and executable code snippet reproducing the issue so I can verify the fix?

@ptrblck I will try to create a code snippet for it. By the way, the error is raised during the backward pass.
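
I will also rerun with blocking kernel launches and anomaly detection enabled to get a more accurate stack trace, roughly:

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before the first CUDA call

import torch
torch.autograd.set_detect_anomaly(True)  # reports the forward op that created the failing backward node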

@ptrblck Below is the executable code snippet. It seems to work fine with F.scaled_dot_product_attention here, but I am not sure why I am getting the error in my case.

import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F

# Define a baseline self-attention layer using plain torch.matmul (no causal mask is applied despite the class name)
class CausalSelfAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(CausalSelfAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        
        self.query_linear = nn.Linear(hidden_dim, hidden_dim)
        self.key_linear = nn.Linear(hidden_dim, hidden_dim)
        self.value_linear = nn.Linear(hidden_dim, hidden_dim)
        
    def forward(self, src):
        batch_size, seq_len, hidden_dim = src.size()
        
        # Linear transformations of query, key, and value
        query = self.query_linear(src)
        key = self.key_linear(src)
        value = self.value_linear(src)
        
        # Reshape query, key, and value for multiple heads
        query = query.view(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
        key = key.view(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
        value = value.view(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
        
        # Transpose for batch matrix multiplication
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        
        # Compute attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / (hidden_dim // self.num_heads)**0.5
              
        # Apply softmax to get attention weights
        attention_weights = nn.functional.softmax(scores, dim=-1)
        
        # Compute attention output
        attention_output = torch.matmul(attention_weights, value)
        
        # Transpose attention output
        attention_output = attention_output.transpose(1, 2)
        
        # Reshape attention output
        attention_output = attention_output.contiguous().view(batch_size, seq_len, hidden_dim)

        return attention_output
    

# Define the ScaledSelfAttention layer using F.scaled_dot_product_attention
class ScaledSelfAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(ScaledSelfAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        
        self.query_linear = nn.Linear(hidden_dim, hidden_dim)
        self.key_linear = nn.Linear(hidden_dim, hidden_dim)
        self.value_linear = nn.Linear(hidden_dim, hidden_dim)
        
    def forward(self, src):
        batch_size, seq_len, hidden_dim = src.size()
        
        # Linear transformations of query, key, and value
        query = self.query_linear(src)
        key = self.key_linear(src)
        value = self.value_linear(src)
        
        # Reshape query, key, and value for multiple heads
        query = query.view(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
        key = key.view(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
        value = value.view(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
        
        # Transpose for batch matrix multiplication
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        
        
        # Compute the attention output with the fused kernel
        attention_output = F.scaled_dot_product_attention(query, key, value)

        # Move the head dimension back and merge the heads: (B, H, N, D) -> (B, N, H*D)
        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(batch_size, seq_len, hidden_dim)
        return attention_output
    

# Define the Transformer model
class Transformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(Transformer, self).__init__()
        self.encoder = nn.Embedding(input_dim, hidden_dim)
        self.positional_encoder = PositionalEncoder(hidden_dim)

        # self.attention = CausalSelfAttention(hidden_dim, num_heads)
        self.attention = ScaledSelfAttention(hidden_dim, num_heads)
        
        self.fc = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, src):
        src = self.encoder(src)
        src = self.positional_encoder(src)
        output = self.attention(src)
        # print("output:", output.shape)
        output = self.fc(output)
        return output

# Positional encoding
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len=100):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        # Register as a buffer so the encoding moves with the module (e.g. on .cuda())
        self.register_buffer("pos_enc", self.positional_encoding(max_seq_len, d_model))
    
    def positional_encoding(self, max_seq_len, d_model):
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pos_enc = torch.zeros(max_seq_len, d_model)
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        return pos_enc.unsqueeze(0)
    
    def forward(self, x):
        x = x * math.sqrt(self.d_model)
        x = x + self.pos_enc[:, :x.size(1), :]
        return x

# Define the model hyperparameters
input_dim = 100  # Input vocabulary size
hidden_dim = 256  # Hidden dimension size
num_layers = 2  # Number of transformer layers
num_heads = 4  # Number of attention heads
num_epochs = 5  # Number of training epochs
learning_rate = 0.001  # Learning rate

# Create an instance of the Transformer model
model = Transformer(input_dim, hidden_dim, num_layers, num_heads)

# Define a dummy input and target
src = torch.randint(0, input_dim, (16, 10))  # Batch size = 16, Sequence length = 10
target = torch.zeros(16, 10, hidden_dim)  # Dummy target of the same shape as the output

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(src)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")

# Test the model on a new input after training
test_input = torch.randint(0, input_dim, (1, 10))  # Batch size = 1, Sequence length = 10
test_output = model(test_input)
print(test_output.shape)  # Output shape: (1, 10, 256)

I am not sure of the exact root cause, but I have solved the issue: I was mistakenly applying dropout twice, once via the dropout_p argument of the API call and a second time with a separate dropout after it. I just removed the one outside the API and it's working now.
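
In other words, my actual module was effectively doing something like this (illustrative sketch, not the snippet above):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 4, 8, 16)  # illustrative shapes only

# Before: dropout applied twice, once inside the API and once more on its output
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
out = F.dropout(out, p=0.1)  # this second, external dropout is the one I removed

# After: only the dropout_p argument of the API is used
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)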

Good to hear it’s working now, but I’m unsure which dropout layers you removed, as none are explicitly used in the posted code snippet.
Were you able to reproduce the issue using your posted snippet and, if so, what did you change to make it work again?