I just came across some curious behaviour with mixed precision in Hugging Face (SFTTrainer and PEFT) while training Mistral-7B.
The embeddings and layer norms are kept in full precision, so the hidden states get silently cast to float32. The modeling script therefore casts the states back to half precision here. I usually train in bfloat16, but in the eval loop the states are suddenly cast to float16.
With autocast enabled, the script calls torch.get_autocast_gpu_dtype() to pick the target dtype. Now to my question:
What exactly does torch.get_autocast_gpu_dtype() do, and how does it infer the dtype? I checked, and it returns bfloat16 in the training loop but float16 during eval.
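As far as I can tell, the dtype comes from thread-local autocast state set by the context manager rather than from the model itself. Here is a minimal repro of what I observe (assuming a CUDA device is available; the printed values are what I get):

import torch

with torch.autocast("cuda", dtype=torch.bfloat16):   # like the training loop
    print(torch.is_autocast_enabled())        # True
    print(torch.get_autocast_gpu_dtype())     # torch.bfloat16

with torch.autocast("cuda"):                  # like my eval loop, no dtype given
    print(torch.is_autocast_enabled())        # True
    print(torch.get_autocast_gpu_dtype())     # torch.float16 (the CUDA default)

print(torch.get_autocast_gpu_dtype())         # torch.float16 again outside any context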
These are the warnings I get; the first appears when I enter the training loop and the second during eval.
[2024-12-11 15:34:08,308][transformers.models.mistral.modeling_mistral][328][WARNING]: The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
{'loss': 3.8342, 'grad_norm': 34.91254425048828, 'learning_rate': 2.5e-06, 'epoch': 0.01}
[2024-12-11 15:34:35,316][transformers.models.mistral.modeling_mistral][328][WARNING]: The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.
They stem from this part of the modeling script:
input_dtype = query_states.dtype
if input_dtype == torch.float32:
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype
    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )
So should torch.get_autocast_gpu_dtype() always return torch.float16?
For training, the autocast context is explicitly set to bfloat16. However, in my custom eval loop
I only used torch.amp.autocast('cuda'). Am I correct in assuming that float16 is used there because it is the default for torch.get_autocast_dtype("cuda")?
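For reference, this is a sketch of what I mean; the Linear layer and random input are just toy stand-ins for my real model and batch to keep the snippet self-contained:

import torch
import torch.nn as nn

model = nn.Linear(16, 16).cuda()        # stand-in for the PEFT model
x = torch.randn(4, 16, device="cuda")   # stand-in for an eval batch

model.eval()
with torch.no_grad():
    # My current eval loop: no dtype argument, so autocast falls back to the
    # CUDA default (float16) -- which would explain the second warning.
    with torch.amp.autocast("cuda"):
        print(model(x).dtype)   # torch.float16

    # Passing the dtype explicitly, like the bfloat16 training context does,
    # should presumably keep eval in bfloat16 as well.
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        print(model(x).dtype)   # torch.bfloat16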