Dtype differs between eval and train loop with mixed precision

I just came across some curious behaviour when using mixed precision with HuggingFace (SFTTrainer and PEFT) while training Mistral-7B.

The embeddings and layer norms are kept in full precision, so the hidden states get silently cast to float32. Their modeling script therefore casts the states back to half precision here. I usually train in bfloat16. However, in the eval loop the hidden states are suddenly cast to float16.
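For illustration, here is a minimal sketch (not the actual Mistral code, and it needs a CUDA device): a layer norm kept in float32, similar to what PEFT does when it upcasts the norm layers, produces float32 hidden states even inside a bfloat16 autocast region:

import torch
import torch.nn as nn

# layer norm kept in full precision, similar to the upcasted norm layers in my setup
ln = nn.LayerNorm(16).cuda().float()
hidden = torch.randn(2, 16, device="cuda", dtype=torch.bfloat16)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = ln(hidden)
    print(out.dtype)
    # torch.float32 -> the "silent" upcast of the hidden states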

With autocast enabled they call torch.get_autocast_gpu_dtype(). Now to my question:

What exactly does torch.get_autocast_gpu_dtype() do or how does it infer the dtype? I have checked and it returns bfloat16 in the training loop and float16 during eval.

Thanks for your help!

The default autocast dtype depends on the device and doesn't change in the eval loop, as seen here:

import torch
import torch.nn as nn

print(torch.get_autocast_dtype("cpu"))
# torch.bfloat16
print(torch.get_autocast_dtype("cuda"))
# torch.float16

print(torch.get_autocast_gpu_dtype())
# DeprecationWarning: torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.
# torch.float16


with torch.no_grad():
    print(torch.get_autocast_dtype("cpu"))
    # torch.bfloat16
    print(torch.get_autocast_dtype("cuda"))
    # torch.float16

    print(torch.get_autocast_gpu_dtype())
    # DeprecationWarning: torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.
    # torch.float16


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        print(torch.get_autocast_dtype("cpu"))
        # torch.bfloat16
        print(torch.get_autocast_dtype("cuda"))
        # torch.float16

        print(torch.get_autocast_gpu_dtype())
        # DeprecationWarning: torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.
        # torch.float16
        return x

model = MyModel()
x = torch.randn(1, 1)
out = model(x)
# torch.bfloat16
# torch.float16
# torch.float16

model.eval()
out = model(x)
# torch.bfloat16
# torch.float16
# torch.float16

with torch.no_grad():
    out = model(x)
    # torch.bfloat16
    # torch.float16
    # torch.float16

Interesting.

These are the warnings that I get. The first one appears when I enter the training loop, the second one during eval.

[2024-12-11 15:34:08,308][transformers.models.mistral.modeling_mistral][328][WARNING]: The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
{'loss': 3.8342, 'grad_norm': 34.91254425048828, 'learning_rate': 2.5e-06, 'epoch': 0.01}                                                                                                                                                                                               
[2024-12-11 15:34:35,316][transformers.models.mistral.modeling_mistral][328][WARNING]: The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.

They stem from this part of the modeling script:

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

So torch.get_autocast_gpu_dtype() should always return torch.float16?

I think I now know what is happening.

For training, the autocast context is explicitly set to bfloat16. However, in my custom eval loop I only used torch.amp.autocast('cuda'). Am I correct in assuming that float16 is therefore used, since it is the default for torch.get_autocast_dtype("cuda")?

Yes, if you don’t specify the dtype in autocast the default will be used:

import torch

with torch.autocast(device_type="cuda"):
    print(torch.get_autocast_dtype("cpu"))
    print(torch.get_autocast_dtype("cuda"))
# torch.bfloat16
# torch.float16

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    print(torch.get_autocast_dtype("cpu"))
    print(torch.get_autocast_dtype("cuda"))
# torch.bfloat16
# torch.bfloat16
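
So for your custom eval loop, passing the dtype explicitly makes eval match training. A minimal sketch with a hypothetical stand-in model (the real loop would of course use your model and eval batches):

import torch
import torch.nn as nn

# hypothetical stand-ins for the real model and eval batch
model = nn.Linear(16, 16).cuda()
batch = torch.randn(4, 16, device="cuda")

model.eval()
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(batch)
    print(torch.get_autocast_dtype("cuda"))  # torch.bfloat16, matching the training loop
    print(out.dtype)                         # torch.bfloat16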

Thank you for the confirmation. Now that I know what I was looking for, I have also found it in the docs.

And as always, thanks for the quick, precise and helpful answers :slight_smile: