PyTorch/XLA SPMD on Cloud TPU Pod

I’m currently trying to train a Vision Transformer on a Cloud TPU Pod using PyTorch/XLA SPMD. I followed the latter’s documentation, and my code is mostly based on their example of running ResNet with SPMD on a single TPU device.

However, their docs state the following:

There is no code change required to go from single TPU host to TPU Pod if you construct your mesh and partition spec based on the number of devices instead of some hardcode constant. To run the PyTorch/XLA workload on TPU Pod.
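The idea in the quoted passage is that the mesh is derived from the runtime device count rather than from a fixed number. A minimal sketch of that idea, assuming torch_xla 2.2 or later (where the SPMD API lives in torch_xla.distributed.spmd; older releases used a different module path) and a simple 1-D data-parallel layout. The helper names data_parallel_mesh_spec, build_mesh and shard_batch are hypothetical, and only the pure helper runs without a TPU:

```python
from typing import List, Tuple


def data_parallel_mesh_spec(num_devices: int) -> Tuple[List[int], Tuple[int, int], Tuple[str, str]]:
    """Return (device_ids, mesh_shape, axis_names) for a data-parallel mesh.

    Nothing is hardcoded to one host: the mesh spans however many devices
    it is given, so the same code covers a single host or a pod slice.
    """
    if num_devices < 1:
        raise ValueError("need at least one device")
    device_ids = list(range(num_devices))
    mesh_shape = (num_devices, 1)  # every device on 'data', size 1 on 'model'
    return device_ids, mesh_shape, ("data", "model")


def build_mesh():
    """Build an SPMD mesh over all runtime devices (requires a TPU VM)."""
    # Imported lazily so the helper above can be exercised without torch_xla.
    # Module paths assume torch_xla >= 2.2.
    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # enable SPMD mode before touching any XLA device
    # The device count comes from the runtime instead of a constant.
    device_ids, mesh_shape, axis_names = data_parallel_mesh_spec(
        xr.global_runtime_device_count()
    )
    return xs.Mesh(np.array(device_ids), mesh_shape, axis_names)


def shard_batch(images, mesh):
    """Annotate an NCHW batch so its batch dimension is split over mesh axis 0."""
    import torch_xla.distributed.spmd as xs

    # Dim 0 is split across the 'data' axis; C, H and W stay replicated.
    xs.mark_sharding(images, mesh, (0, None, None, None))
    return images
```

As we understand the docs, global_runtime_device_count() reports devices across all hosts in SPMD mode, so the same script launched on every worker builds a mesh sized for the whole slice; that is why no code change is supposed to be needed.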

I truly can’t understand how it is possible to run the exact same code and have it distributed over all the TPU Pod VMs just by invoking the same gcloud command with the --worker=all argument. Does it work out of the box? And how can I check whether the data and weights are sharded across all hosts in the pod?

I’m finding PyTorch/XLA difficult to use given the lack of documentation and examples for TPU Pods.

Hi Abdelkrim, the best way to monitor your workload (and see whether it’s utilizing all the devices) is to profile it, following the Cloud TPU guide "PyTorch XLA performance profiling", and check the profiled traces.

You can also enable logging to see whether your model is being executed across multiple devices, e.g.:

TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=xla_graph_executor=5 python test/spmd/ --fake_data --num_epoch=1 --sharding batch

This should print out useful debugging info, like

2023-11-21 01:49:24.349482: I torch_xla/csrc/xla_graph_executor.cpp:994] Executing IR graph hash 4b2ecd7c298db70445084a8ad90e9fd9 on devices: TPU:0,TPU:1,TPU:2,TPU:3 done!
2023-11-21 01:49:24.352578: I torch_xla/csrc/xla_graph_executor.cpp:603] waiting barrier for device SPMD:0 done

on each host. Here, "Executing IR graph hash 4b2ecd7c298db70445084a8ad90e9fd9 on devices: TPU:0,TPU:1,TPU:2,TPU:3 done!" means that this host executed the graph on all 4 of its local devices.

When you run your training on a pod slice with --worker=all --command={your command}, could you try using {your command} 2>&1 | tee output.txt as the command instead, and then check output.txt on each host afterwards?
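Once each host has written its output.txt, a small script can confirm which devices appear in the execution lines. A minimal sketch, assuming the log line format shown in the example output above and 4 local devices per host as in that example (the right number depends on the TPU type); devices_used and check_log are hypothetical helper names:

```python
import re
from typing import Iterable, Set

# Matches lines like:
# "... Executing IR graph hash <hash> on devices: TPU:0,TPU:1,TPU:2,TPU:3 done!"
_EXEC_RE = re.compile(r"Executing IR graph hash (\w+) on devices: ([\w:,]+) done!")


def devices_used(lines: Iterable[str]) -> Set[str]:
    """Collect every device name reported in 'Executing IR graph' lines."""
    used: Set[str] = set()
    for line in lines:
        match = _EXEC_RE.search(line)
        if match:
            used.update(match.group(2).split(","))
    return used


def check_log(path: str, expected_local_devices: int = 4) -> bool:
    """Return True if this host's log shows execution on the expected number of devices."""
    with open(path, encoding="utf-8", errors="replace") as f:
        used = devices_used(f)
    print(f"{path}: {sorted(used)}")
    return len(used) == expected_local_devices
```

Running check_log("output.txt") on every worker (for example through the same --worker=all mechanism) gives a quick per-host yes/no; the profiler traces remain the more complete check.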

Hi Yeounoh, thank you very much, and I apologize for the late response; I paused the project to study for exams.

I just tried setting the environment variables as you instructed and got a huge file of debugging info. The screenshot below contains some information similar to your example: