PyTorch uses float32 by default on both CPU and GPU. I'm not deeply familiar with TPUs, but I would guess they default to bfloat16. Could you call float() on the model and the inputs, then check whether the TPU run is forcing this format anyway?
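Something like this minimal sketch might help, using a hypothetical toy model just for illustration (not your actual setup): cast explicitly to float32, then print the dtypes right before and after the forward pass inside your TPU run.

```python
import torch
import torch.nn as nn

# Toy model and input for illustration only
model = nn.Linear(16, 4)
x = torch.randn(8, 16)

# Explicitly cast parameters and inputs to float32
model = model.float()
x = x.float()

# Verify the dtypes before the forward pass
print(next(model.parameters()).dtype)  # expected: torch.float32
print(x.dtype)                         # expected: torch.float32

out = model(x)
# If this prints torch.bfloat16 inside your TPU run, the backend
# is overriding the explicit float() cast
print(out.dtype)
```

If any of these print torch.bfloat16 on the TPU, the backend is downcasting behind your back. If I remember correctly, torch_xla also reads an XLA_USE_BF16 environment variable that forces this behavior, so it might be worth checking whether it's set in your environment.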