AssertionError: torch.float32 != torch.bfloat16

Hi all,

I encountered an issue while using torch.compile on a model with FSDP enabled. The assertion error is raised at this line during fake tensor verification. The model I am compiling is from this file.

After investigating the issue, I found some insights that might be helpful.

Let’s consider the tensor cos_cached in the model code referenced above.

  1. cos_cached is a view: emb.cos() produces a tensor, and the subsequent slicing operation creates a view of it.
  2. cos_cached is copied to the GPU here. In this step, the tensor's .data field is assigned directly. As a result, cos_cached still reports being a view, but it no longer shares storage with its base tensor.
  3. cos_cached is cast according to automatic mixed precision here. In my case, cos_cached was originally float32 and is now cast to bf16. Consequently, cos_cached is still a view that no longer shares data with its base, and their dtypes now differ (see the sketch below).
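
For reference, here is a minimal standalone sketch of the state I believe cos_cached ends up in. This is my own reconstruction, not the model's actual code: `base` stands in for the tensor returned by emb.cos(), `view` for the sliced cos_cached buffer, and the shapes are made up.

```python
import torch

# base stands in for the tensor returned by emb.cos();
# view stands in for cos_cached, which is created by slicing.
base = torch.randn(16, 32)
view = base[None, None, :, :]
print(view._is_view(), view.dtype)           # True torch.float32

# Step 2: move "cos_cached" to the GPU by assigning .data directly.
# The tensor still reports being a view, but its storage is now a
# separate allocation that base knows nothing about.
if torch.cuda.is_available():
    view.data = view.data.to("cuda")

# Step 3: the autocast-style cast, again through .data.
view.data = view.data.to(torch.bfloat16)

# The "view" is now bf16 while the base it still refers to is float32.
print(view._is_view(), view.dtype, view._base.dtype)
# True torch.bfloat16 torch.float32
```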

Finally, the dtype-checking assertion fails when a fake tensor is created for cos_cached.
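
If it helps, I would expect the same assertion to be reproducible outside of torch.compile by fakifying a tensor in that state directly, along these lines (untested sketch, continuing from the snippet above):

```python
from torch._subclasses import FakeTensorMode

# Try to fakify the mismatched "view" tensor from the sketch above.
fake_mode = FakeTensorMode()
fake_view = fake_mode.from_tensor(view)
# Expected (untested): AssertionError: torch.float32 != torch.bfloat16
```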

I am wondering whether this should be considered a bug, and whether directly modifying tensor.data is recommended in cases like this.

Thanks,
Yifei