RuntimeError: expected scalar type Float but found BFloat16

I’m running my model in bf16-mixed precision and encountering the following error:

RuntimeError: expected scalar type Float but found BFloat16

This happens when I apply an nn.TransformerEncoder model to my data.

My input consists of two different modalities:

Tensor H: Shape (batch_size, 10000). I pass this through a 1D CNN feature extractor, then into an nn.TransformerEncoder. This entire pipeline works fine under bf16-mixed precision.

Tensor C: Shape (batch_size, 80). Here, I generate “tokens”, create a padding mask, and then use an nn.Embedding layer before passing the result to the transformer. This is where I hit the error mentioned above.

Since everything works for Tensor H, I suspect this issue may be specific to how nn.Embedding behaves under bf16-mixed precision.
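As far as I can tell, nn.Embedding is not on autocast's list of ops that run in lower precision, so its output stays float32 (the dtype of its weight) even inside an autocast region, whereas the convolutions in the Tensor H path produce bfloat16 activations. A quick standalone check along these lines (a minimal sketch, not my actual modules) seems consistent with that:

import torch
import torch.nn as nn

emb = nn.Embedding(100, 64).cuda()
conv = nn.Conv1d(1, 64, kernel_size=3).cuda()

tokens = torch.randint(0, 100, (2, 80), device='cuda')
signal = torch.randn(2, 1, 10000, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    print(emb(tokens).dtype)   # torch.float32: embedding is not autocast
    print(conv(signal).dtype)  # torch.bfloat16: conv runs in lower precision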

Below is the relevant code:

# ... other stuff above...

    if self.spectrum_type == 'cnmr' and self.cnmr_binary: # This is Tensor C
        # Tokenize the binary CNMR data

        tokens = self._tokenize_cnmr(x)
        
        # Create mask for transformer (True values will be ignored)
        mask = (tokens == 0).bool()
        
        # Embed the tokens (this is our feature extraction)
        x = self.feature_extractor(tokens)  # [batch_size, seq_len, d_model]
        
        # Apply transformer with mask
        x = self.transformer(x, src_key_padding_mask=mask) # This throws the error :'(

        return x, mask
        
    else:  # This is Tensor H
        # Add channel dimension
        x = x.unsqueeze(1)  # [batch_size, 1, sequence_length]
        
        # Apply integrated feature extraction (including pooling if configured)
        x = self.feature_extractor(x)  # [batch_size, channels, seq_len]
        
        # Reshape for projection
        x = x.transpose(1, 2)  # [batch_size, seq_len, channels]
        x = self.post_conv_proj(x)  # [batch_size, seq_len, d_model]
        
        # Add positional encoding if needed
        if self.use_pos_encoding:
            x = self.pos_encoding(x)
        
        # Apply transformer encoder
        x = self.transformer(x)
        
        # Create empty mask (no padding)
        mask = torch.zeros(x.shape[0], x.shape[1], dtype=torch.bool, device=x.device)
        
        return x, mask

# ... other stuff below...

Any guidance on how to resolve this, or whether explicit casting is needed somewhere, would be appreciated!
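For reference, this is the kind of explicit cast I had in mind, right before the transformer call in the Tensor C branch (an untested sketch; I'm assuming torch.get_autocast_gpu_dtype() returns the active autocast dtype):

# Untested sketch: align the embedding output with the autocast dtype
x = self.feature_extractor(tokens)  # float32 out of nn.Embedding
if torch.is_autocast_enabled():
    x = x.to(torch.get_autocast_gpu_dtype())  # bfloat16 under bf16-mixed
x = self.transformer(x, src_key_padding_mask=mask)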

Are you using autocast or do you manually cast the data?
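That is, roughly one of these two setups:

# 1) autocast: parameters stay float32, eligible ops run in bfloat16
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    out = model(x)

# 2) manual casting: model and data are converted outright
model = model.to(torch.bfloat16)
out = model(x.to(torch.bfloat16))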