Why are the outputs of baddbmm fp16 when its inputs are fp32?

I’m trying to run StarCoder with QLoRA and found that the loss has been 0 since the very first step.
While debugging, I found that the inputs of baddbmm (attn_weights, query, key) are fp32, but the output is fp16.

attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)

Why does baddbmm return a different dtype than its inputs?

What are the dtypes of all inputs? Also, are you executing the code in an autocast context?

Thanks for your response.
Regarding your questions:
1. No autocast.
2. dtypes: attn_weights=float32, query=float32, key=float32, beta=int, alpha=float; the returned attn_weights is float16.

In this case no casting will be performed, as seen here:

import torch

device = "cuda"
attn_weights = torch.randn(1, 10, 10, device=device)
query = torch.randn(1, 10, 10, device=device)
key = torch.randn(1, 10, 10, device=device)
beta = 1
scale_factor = 1.

attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor)
print(attn_weights.dtype)
# torch.float32
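For contrast, here is a minimal sketch of the same call made inside an autocast region. It assumes CUDA autocast with float16 when a GPU is available and falls back to CPU autocast with bfloat16 otherwise, so it also runs without a GPU. baddbmm is on autocast's lower-precision op list, so the result comes back in the autocast dtype even though every input is float32.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumption: fp16 autocast on GPU, bf16 autocast on CPU.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

attn_weights = torch.randn(1, 10, 10, device=device)
query = torch.randn(1, 10, 10, device=device)
key = torch.randn(1, 10, 10, device=device)

# Outside autocast: float32 in, float32 out.
out_plain = torch.baddbmm(attn_weights, query, key, beta=1, alpha=1.0)

# Inside autocast: baddbmm is on the lower-precision op list,
# so the float32 inputs are cast down before the kernel runs.
with torch.autocast(device_type=device, dtype=amp_dtype):
    out_amp = torch.baddbmm(attn_weights, query, key, beta=1, alpha=1.0)

# float32 outside autocast, the autocast dtype inside
print(out_plain.dtype, out_amp.dtype)
```

The inputs are identical in both calls; only the surrounding context differs.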

The issue appears in the Hugging Face Transformers code. The code file is located as follows:


The _attn function within this file is causing the problem.

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    dtype = query.dtype
    softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
    upcast = dtype != softmax_dtype

    unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
    scale_factor = unscale**-1
    if self.scale_attn_weights:
        scale_factor /= self.head_dim**0.5

    query_shape = query.shape
    batch_size = query_shape[0]
    key_length = key.size(-1)
    if self.multi_query:
        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
        # -> (batch_size, query_length, num_heads, key_length)
        query_length = query_shape[1]
        attn_shape = (batch_size, query_length, self.num_heads, key_length)
        attn_view = (batch_size, query_length * self.num_heads, key_length)
        # No copy needed for MQA 2, or when layer_past is provided.
        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
    else:
        # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
        # -> (batch_size, num_heads, query_length, key_length)
        query_length = query_shape[2]
        attn_shape = (batch_size, self.num_heads, query_length, key_length)
        attn_view = (batch_size * self.num_heads, query_length, key_length)
        # Always copies
        query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
        # No copy when layer_past is provided.
        key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)

    attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
    if query.device.type == "cpu":
        # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
        # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
        # but the fix has not been released as of pytorch version 2.0.0.
        attn_weights = torch.zeros_like(attn_weights)
        beta = 1
    else:
        beta = 0
    # !! error here: float32 inputs, float16 output !!
    attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
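torch.baddbmm computes beta * input + alpha * (batch1 @ batch2). The sketch below runs on CPU with small made-up shapes standing in for the reshaped query and key in _attn. It checks that the beta=0 path ignores whatever the preallocated buffer holds, and that the CPU path's beta=1 only matches it when the buffer holds zeros. Neither path decides the output dtype: with autocast off, the dtype simply follows the float32 inputs.

```python
import torch

torch.manual_seed(0)

# Made-up shapes standing in for the reshaped tensors in _attn:
# query (batch * heads, query_length, head_dim), key (batch * heads, head_dim, key_length)
query = torch.randn(4, 3, 8)
key = torch.randn(4, 8, 5)
scale_factor = 8 ** -0.5

reference = scale_factor * torch.bmm(query, key)

# beta=0 (the non-CPU path): the buffer's contents are ignored,
# so finite junk stands in for an uninitialized torch.empty buffer.
junk = torch.full((4, 3, 5), 123.0)
beta0 = torch.baddbmm(junk, query, key, beta=0, alpha=scale_factor)

# beta=1 (the CPU path): only equivalent when the buffer holds zeros.
zeros = torch.zeros(4, 3, 5)
beta1 = torch.baddbmm(zeros, query, key, beta=1, alpha=scale_factor)

print(torch.allclose(beta0, reference, atol=1e-6))
print(torch.allclose(beta1, reference, atol=1e-6))
print(beta0.dtype)  # float32: with autocast off, the dtype follows the inputs
```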

Additional information about the test environment:

Collecting environment information…
PyTorch version: 2.0.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows Server 2019 Datacenter
GCC version: (tdm64-1) 5.1.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.16 (main, Mar 8 2023, 10:39:24) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.17763-SP0
Is CUDA available: True
CUDA runtime version: 11.6.55
GPU models and configuration:
GPU 0: Tesla V100-PCIE-32GB
GPU 1: Tesla V100-PCIE-32GB
GPU 2: Tesla V100-PCIE-32GB
GPU 3: Tesla V100-PCIE-32GB

Nvidia driver version: 511.65
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Name=Intel(R) Xeon(R) Gold 6242R CPU @ 3.10GHz

Name=Intel(R) Xeon(R) Gold 6242R CPU @ 3.10GHz

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] torch==2.0.0+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchvision==0.15.2+cu118
[conda] numpy 1.24.3 pypi_0 pypi
[conda] torch 2.0.1+cu118 pypi_0 pypi
[conda] torchaudio 2.0.2+cu118 pypi_0 pypi
[conda] torchvision 0.15.2+cu118 pypi_0 pypi

I ran exactly the same test in my lab and got the same result (float32). However, when the same call runs inside the code above, the output dtype becomes bf16.

My guess is that HF enables autocast for you behind your back, without you realizing it.
PyTorch won’t cast activations down unless the user asks for it. Unfortunately, your code is not executable, so if you get stuck, narrow it down further and post a minimal, executable code snippet that reproduces the issue.

Thanks for your support! I’ll try to narrow it down.


If you have access to the forward methods, you could use torch.is_autocast_enabled() to check whether you are inside an autocast context and narrow it down further.
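If editing the library source is inconvenient, a forward pre-hook can report the autocast state each time a module runs. This is a minimal sketch: autocast_active, report_autocast and the toy nn.Linear model are hypothetical stand-ins for the real attention module. Called without arguments, torch.is_autocast_enabled() reports the CUDA autocast state, so the helper queries the device-specific state and falls back to the older no-argument API where the device argument is not accepted.

```python
import torch
from torch import nn


def autocast_active(device_type: str) -> bool:
    """Hypothetical helper: report whether autocast is enabled for device_type."""
    try:
        # Newer PyTorch releases accept a device type argument.
        return torch.is_autocast_enabled(device_type)
    except TypeError:
        # Older releases: the no-arg form reports CUDA; CPU has its own query.
        if device_type == "cuda":
            return torch.is_autocast_enabled()
        return torch.is_autocast_cpu_enabled()


seen = []


def report_autocast(module, args):
    # Hypothetical hook: record the autocast state each time the module runs.
    seen.append(autocast_active(args[0].device.type))


model = nn.Linear(10, 10)  # stand-in for the attention module
handle = model.register_forward_pre_hook(report_autocast)

x = torch.randn(2, 10)
model(x)  # plain call
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    model(x)  # call under autocast, as a higher-level API might do
handle.remove()

print(seen)
```

The hook leaves the module's inputs untouched because it returns None; it only records the state.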

Thanks for the advice! I inserted torch.is_autocast_enabled() into the code and found that it actually returns True.

Thanks for your support! :smiley:

OK, great! Then I would indeed assume that a higher-level API enables it for you.
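If one particular product should stay in float32 while the outer API keeps autocast on, one option is a nested autocast region with enabled=False. The inputs are also cast back to float32 because tensors produced earlier under autocast may already be in lower precision. This is a minimal sketch, assuming CPU bfloat16 autocast stands in for whatever the higher-level API enables. fp32_baddbmm is a hypothetical helper, not part of Transformers, and whether this alone fixes the zero loss depends on what else runs in reduced precision.

```python
import torch


def fp32_baddbmm(inp, batch1, batch2, beta=1, alpha=1.0):
    """Hypothetical helper: run baddbmm in float32 even inside an autocast region."""
    device_type = batch1.device.type
    # Locally switch autocast off for this device, then upcast the inputs explicitly.
    with torch.autocast(device_type=device_type, enabled=False):
        return torch.baddbmm(
            inp.float(), batch1.float(), batch2.float(), beta=beta, alpha=alpha
        )


a = torch.randn(1, 10, 10)
q = torch.randn(1, 10, 10)
k = torch.randn(1, 10, 10)

# Simulates a higher-level API that turned autocast on for you.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    cast_down = torch.baddbmm(a, q, k)
    kept_fp32 = fp32_baddbmm(a, q, k)

print(cast_down.dtype, kept_fp32.dtype)
```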