I’m running my model in bf16-mixed precision and encountering the following error:
RuntimeError: expected scalar type Float but found BFloat16
This happens when I apply an nn.TransformerEncoder model to my data.
My input consists of two different modalities:
Tensor H: Shape (batch_size, 10000). I pass this through a 1D CNN feature extractor, then into an nn.TransformerEncoder. This entire pipeline works fine under bf16-mixed precision.
Tensor C: Shape (batch_size, 80). Here, I generate “tokens”, apply a padding mask, and then use an nn.Embedding layer before passing the result to the transformer. This is where I hit the error mentioned above.
Since everything works for Tensor H, I suspect this issue may be specific to how nn.Embedding behaves under bf16-mixed precision.
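As a quick check of that suspicion, here is a minimal sketch that compares the output dtype of an nn.Embedding with that of an nn.Linear inside an autocast region. The layer sizes and token values are made up for illustration, and it uses CPU autocast so that it runs without a GPU. It only inspects dtypes and does not reproduce the error itself.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for this illustration.
embedding = nn.Embedding(num_embeddings=10, embedding_dim=8, padding_idx=0)
linear = nn.Linear(8, 8)

# One batch of three token ids; 0 plays the role of the padding token.
tokens = torch.tensor([[1, 2, 0]])

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    emb_out = embedding(tokens)  # an embedding lookup is not cast by autocast
    lin_out = linear(emb_out)    # linear layers run in the autocast dtype

print("embedding output dtype:", emb_out.dtype)
print("linear output dtype:", lin_out.dtype)
```

If the embedding output stays in float32 while the downstream layers produce bfloat16, the tensor entering the transformer in the Tensor C branch has a different dtype from the one produced by the CNN branch, which may be relevant to the mismatch.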
Here is the relevant code:
# ... other stuff above ...
if self.spectrum_type == 'cnmr' and self.cnmr_binary:  # This is Tensor C
    # Tokenize the binary CNMR data
    tokens = self._tokenize_cnmr(x)
    # Create mask for transformer (True values will be ignored)
    mask = (tokens == 0).bool()
    # Embed the tokens (this is our feature extraction)
    x = self.feature_extractor(tokens)  # [batch_size, seq_len, d_model]
    # Apply transformer with mask
    x = self.transformer(x, src_key_padding_mask=mask)  # This throws the error :'(
    return x, mask
else:  # This is Tensor H
    # Add channel dimension
    x = x.unsqueeze(1)  # [batch_size, 1, sequence_length]
    # Apply integrated feature extraction (including pooling if configured)
    x = self.feature_extractor(x)  # [batch_size, channels, seq_len]
    # Reshape for projection
    x = x.transpose(1, 2)  # [batch_size, seq_len, channels]
    x = self.post_conv_proj(x)  # [batch_size, seq_len, d_model]
    # Add positional encoding if needed
    if self.use_pos_encoding:
        x = self.pos_encoding(x)
    # Apply transformer encoder
    x = self.transformer(x)
    # Create empty mask (no padding)
    mask = torch.zeros(x.shape[0], x.shape[1], dtype=torch.bool, device=x.device)
    return x, mask
# ... other stuff below ...
Any guidance on how to resolve this, or whether explicit casting is needed somewhere, would be appreciated!