Train a PyTorch model on a TPU v3 Pod built from multiple TPU devices

I want to train a large model on a TPU v3 Pod with 5 TPU devices. I am very new to TPUs. I have already written a model that I train on multiple GPUs (4 V100s) using DataParallel, which I found very easy to incorporate (a rough sketch of that setup is below the questions). I have a couple of concerns about training this model on a Google Cloud TPU:

  • Can I train the same model with DataParallel on a Cloud TPU v3 device with 5 TPUs, or do I need to make any modifications other than switching the library to XLA?
  • Should I use DataParallel or DistributedDataParallel to train the model on a TPU Pod?
  • Does anyone have experience with PyTorch Lightning on a multi-device TPU Pod?
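For context, my current multi-GPU training looks roughly like this (a minimal sketch; the tiny linear model stands in for my actual, much larger one):

```python
import torch
import torch.nn as nn

# Placeholder model; the real model is much larger.
model = nn.Linear(128, 10)
model = nn.DataParallel(model)   # replicates the model across the 4 V100s
model = model.cuda()

# DataParallel splits each batch across the GPUs automatically.
x = torch.randn(64, 128).cuda()
out = model(x)
```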

Sorry for the novice-level questions. Any resources or suggestions would be a great help.

cc @ailzhang for TPU questions :slight_smile:

  • Should I use DataParallel or DistributedDataParallel to train the model on a TPU Pod?

General guidance for DataParallel vs. DistributedDataParallel: Getting Started with Distributed Data Parallel — PyTorch Tutorials 2.1.1+cu121 documentation
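On TPUs specifically, DataParallel is not the usual path; torch_xla's multiprocessing API plays the role that DistributedDataParallel plays on GPUs, with one process per TPU core. Below is a minimal sketch assuming a single TPU v3-8 host (8 cores); the toy model, random data, and learning rate are placeholders, and data sharding across processes is omitted for brevity:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()                    # the TPU core owned by this process
    model = nn.Linear(128, 10).to(device)       # placeholder for the real model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Placeholder dataset; a real run would shard data across processes,
    # e.g. with a DistributedSampler built from xm.get_ordinal() / xm.xrt_world_size().
    dataset = torch.utils.data.TensorDataset(
        torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
    loader = torch.utils.data.DataLoader(dataset, batch_size=64)
    device_loader = pl.MpDeviceLoader(loader, device)  # moves batches onto the TPU core

    model.train()
    for data, target in device_loader:
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduces gradients across cores, then steps


if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8)  # one process per TPU core on a v3-8
```

PyTorch Lightning drives the same machinery under the hood when you configure its Trainer for TPUs (e.g. accelerator="tpu", devices=8 in recent releases). A multi-host Pod slice additionally requires launching the training script on every host of the slice.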
