I am trying XLA on TPUs for the first time; all my other training has been running on Flax. I need to move a Whisper fine-tuning script from GPU to TPU (and eventually TPU Pods). I am using a simple TPU VM v4.
I am able to set up the TPU, and the sample scripts run at the expected speed. My own script also starts without any errors: it looks like the xla-hook-script is initiated, and I have set the variable “xla”=True in the Trainer script. There are no error messages, but everything is extremely slow. On a standard A6000 I get around 16 s/it; on the TPU I am seeing 800 s/it! It is as if I am training on the CPU.
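Is there a quick way to confirm the TPU cores are actually being used? Since JAX is on the VM for my Flax training anyway, I have been checking with something like this (just a sketch, assuming a default JAX install on the TPU VM):

```python
import jax

# Sanity check: which backend did the runtime pick up?
# On a healthy v4 VM this should list TpuDevice entries,
# not just CpuDevice, and the default backend should be "tpu".
devices = jax.devices()
print(devices)
print(jax.default_backend())  # "cpu" here would explain the speed
```

If this reports only CPU devices, the slowdown would simply be CPU execution rather than anything XLA-specific.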
Could anyone give me some basic tips on how to debug this? If you are interested, the script is available here: NbAiLab/whisper · Hugging Face
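One candidate I am already suspicious of is per-step recompilation: XLA compiles a separate graph for each distinct input shape, so variable-length audio/label batches can force a fresh compile on nearly every step. A minimal JAX sketch of that behaviour (not my actual script, just an illustration):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def step(x):
    # The Python body only runs when JAX traces (and XLA compiles)
    # a new graph, i.e. once per distinct input shape/dtype.
    global trace_count
    trace_count += 1
    return (x * 2.0).sum()

step(jnp.ones((4, 8)))   # first call: traced and compiled
step(jnp.ones((4, 8)))   # same shape: cached, no recompile
step(jnp.ones((4, 9)))   # new shape: traced and compiled again
print(trace_count)       # 2
```

If this is the cause, padding every batch to a fixed length should make the step time drop sharply after the first few iterations.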