Backward pass of scaled_dot_product_attention fails on H100

Hi!

I’m encountering an issue where the backward pass of torch.nn.functional.scaled_dot_product_attention fails on an H100 GPU but not on an A100 GPU.

I’ve tested this with the following script:

import logging
import sys

import torch
import torch.nn.functional as F


def main():
    # setup
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(stream=sys.stdout)
    formatter = logging.Formatter(
        fmt="%(asctime)s %(levelname).1s %(message)s",
        datefmt="%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logger.handlers.append(handler)
    device = torch.device("cuda:0")

    # log versions
    logging.info(f"torch.version {torch.__version__}")
    logging.info(f"torch.version.cuda {torch.version.cuda}")
    logging.info(f"device name {torch.cuda.get_device_name()}")
    logging.info(f"device capability {torch.cuda.get_device_capability()}")
    logging.info(f"device properties {torch.cuda.get_device_properties(device)}")

    # init qkv
    dim = 768
    num_heads = 16
    qkv = torch.nn.Linear(dim, dim * 3).to(device)

    # simulate forward pass of a VisionTransformer
    x = torch.randn(4, 197, dim, device=device)
    B, N, C = x.shape
    logging.info("qkv")
    qkv = qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    # forward/backward
    logging.info("scaled_dot_product_attention")
    with torch.autocast("cuda", dtype=torch.bfloat16):
        x = F.scaled_dot_product_attention(q, k, v)
    logging.info("backward")
    x.mean().backward()
    logging.info("fin")


if __name__ == "__main__":
    main()

which results in the following output on an H100:

04-07 11:30:42 I torch.version 2.0.0+cu118
04-07 11:30:42 I torch.version.cuda 11.8
04-07 11:30:42 I device name NVIDIA H100 PCIe
04-07 11:30:42 I device capability (9, 0)
04-07 11:30:42 I device properties _CudaDeviceProperties(name='NVIDIA H100 PCIe', major=9, minor=0, total_memory=81075MB, multi_processor_count=114)
04-07 11:30:42 I qkv
04-07 11:30:44 I scaled_dot_product_attention
04-07 11:30:44 I backward
Traceback (most recent call last):
  File ".../scripts/setup_native_flash_attn.py", line 49, in <module>
    main()
  File ".../scripts/setup_native_flash_attn.py", line 44, in main
    x.mean().backward()
  File ".../lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File ".../lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: an illegal instruction was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

and in the following output on an A100:

04-07 11:33:38 I torch.version 2.0.0+cu118
04-07 11:33:38 I torch.version.cuda 11.8
04-07 11:33:38 I device name NVIDIA A100-PCIE-40GB
04-07 11:33:38 I device capability (8, 0)
04-07 11:33:38 I device properties _CudaDeviceProperties(name='NVIDIA A100-PCIE-40GB', major=8, minor=0, total_memory=40384MB, multi_processor_count=108)
04-07 11:33:39 I qkv
04-07 11:33:40 I scaled_dot_product_attention
04-07 11:33:40 I backward
04-07 11:33:40 I fin

Some additional observations:
  • the same thing happens without mixed precision
  • setting CUDA_LAUNCH_BLOCKING=1 also doesn’t help

Is the H100 not supported yet, or am I missing something here?

Thanks for reporting the issue!
I can reproduce an illegal instruction using torch==2.0.0+cu118:

========= Illegal instruction
=========     at 0x48f0 in void attention_kernel_backward_batched<AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, (bool)1, (int)64>>(T1::Params)
=========     by thread (0,1,0) in block (0,4,0)
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x30b492]
=========                in /usr/local/cuda/compat/lib.real/libcuda.so.1
=========     Host Frame: [0x1488c]
=========                in /usr/local/lib/python3.8/dist-packages/torch/lib/libcudart-d0da41ae.so.11.0
=========     Host Frame:cudaLaunchKernel [0x6c318]
=========                in /usr/local/lib/python3.8/dist-packages/torch/lib/libcudart-d0da41ae.so.11.0
=========     Host Frame:void attention_kernel_backward_batched<AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, 64> >(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, 64>::Params) [0x2f386cb]

which points to incorrect kernel dispatching, since an sm_80 (Ampere) kernel is launched on an sm_90 (Hopper) device.
The current nightly binary raises a RuntimeError instead:

RuntimeError: Expected q_dtype == at::kHalf || (is_sm8x && q_dtype == at::kBFloat16) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

We should follow up by enabling the slow path for these devices and eventually also allow the optimized path.
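
A possible workaround until then is to restrict F.scaled_dot_product_attention to the math backend, which should avoid the fused CUTLASS kernels entirely. A minimal sketch (the helper name sdpa_math_only is just for illustration):

import torch
import torch.nn.functional as F


def sdpa_math_only(q, k, v):
    # disable the flash and memory-efficient backends; only the unfused
    # "math" implementation remains, which should run on any device
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    ):
        return F.scaled_dot_product_attention(q, k, v)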


I’m facing the same issue with torch==2.0.0+cu118 on an H100 GPU. The following is my error message:

  File "rlmeta/rlmeta/agents/ppo/ppo_agent.py", line 227, in _train_step
    loss.backward()
  File "~/miniconda3/envs/macta/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "~/miniconda3/envs/macta/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: an illegal instruction was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

How can I resolve this issue?

You can check the device capability and swap the attention out for a slow implementation. I did this by wrapping the attention in a separate module and selecting the module based on the CUDA compute capability (selection logic first, then the wrapper module):

attn_kwargs = dict(
    dim=dim,
    num_heads=num_heads,
    qkv_bias=qkv_bias,
    attn_drop=attn_drop,
    proj_drop=drop,
)
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
    # Hopper (sm_90): fall back to the slow timm Attention
    attn = Attention(**attn_kwargs)
else:
    # all other devices: use the F.scaled_dot_product_attention wrapper
    attn = NativeFlashAttention(**attn_kwargs)

import einops
import torch.nn as nn
import torch.nn.functional as F


class NativeFlashAttention(nn.Module):
    """ timm.models.vision_transformer.Attention but with scaled_dot_product_attention """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert hasattr(F, 'scaled_dot_product_attention')
        assert attn_drop == 0, "F.scaled_dot_product_attention dropout has no train/eval mode"
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        q, k, v = einops.rearrange(
            self.qkv(x),
            "bs seqlen (three num_heads head_dim) -> three bs num_heads seqlen head_dim",
            three=3,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
        ).unbind(0)

        x = F.scaled_dot_product_attention(q, k, v)

        x = einops.rearrange(x, "bs num_heads seqlen head_dim -> bs seqlen (num_heads head_dim)")
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

As the slow implementation I use the Attention module from timm; a rough sketch of what it does is below.
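
For reference, this is roughly what that slow module computes (modeled on timm.models.vision_transformer.Attention, not the verbatim timm code, and reusing the imports from the snippet above; the name SlowAttention is just for illustration):

class SlowAttention(nn.Module):
    """ explicit softmax attention, no call into F.scaled_dot_product_attention """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        q, k, v = einops.rearrange(
            self.qkv(x),
            "bs seqlen (three num_heads head_dim) -> three bs num_heads seqlen head_dim",
            three=3,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
        ).unbind(0)

        # explicit attention: scaled scores -> softmax -> weighted sum of values
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        x = einops.rearrange(x, "bs num_heads seqlen head_dim -> bs seqlen (num_heads head_dim)")
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

You would construct it wherever Attention is created in the selection snippet above.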


@BenediktAlkin, thank you for the suggestion. We are using a Transformer model like the one below.

        self.action_embed = nn.Embedding(self.action_dim,
                                         self.action_embed_dim)
        self.step_embed = nn.Embedding(self.step_dim, self.step_embed_dim)

        self.linear_i = nn.Linear(self.input_dim, self.hidden_dim)
        # self.linear_o = nn.Linear(self.hidden_dim * self.window_size,
        #                           self.hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(d_model=self.hidden_dim,
                                                   nhead=8,
                                                   dropout=0.0)
        self.encoder = nn.TransformerEncoder(encoder_layer, self.num_layers)

How can I modify this to use the slow path?

Thanks in advance.

I tried modifying the attention layer (MultiheadAttention in our case) to bypass the fast path. I simply either commented out the fast-path check or set the `why_not_fast_path` variable to some string in the `forward()` function, like below.

class AutocatMultiheadAttention(Module):
    # forward() is copied from nn.MultiheadAttention, with the fast-path check overridden
    def forward(self, *args, **kwargs):
        why_not_fast_path = "H100 & Transformer issue"  # any non-empty string disables the fast path
        ...  # rest of the copied forward() unchanged

My assumption was that this would bypass the fast path and force the slow path, but I still face the error below.

RuntimeError: CUDA error: an illegal instruction was encountered 

Am I missing something here? Is there any other way to force the slow path?

I was able to figure out the issue. It was due to the fast path: I had to modify the Transformer and attention code to force it to recompute the attention weights. This resolved my issue.
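
For anyone else hitting this, a minimal sketch of the general idea with a plain nn.MultiheadAttention (not our exact change, just an illustration; sizes are arbitrary): requesting the attention weights via need_weights=True keeps MultiheadAttention on the explicit, unfused implementation.

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True).cuda()
x = torch.randn(4, 197, 512, device="cuda")

# need_weights=True requests the attention weight matrix, which forces the
# attention scores to be computed explicitly instead of via fused kernels
out, attn_weights = mha(x, x, x, need_weights=True)
out.mean().backward()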

I ran into the same problem when calling loss.backward(). When I switched torch to the Preview (Nightly) version, it worked.