Enable multiprocessing on PyTorch/XLA for a TPU VM

I’m fairly new to this and have little to no experience.
I had a PyTorch notebook that I wanted to run on a Google Cloud TPU VM.
Machine specs:

* Ubuntu
* TPU v2-8
* pt-2.0

I should have 8 cores. Correct me if I’m wrong.

So, I followed the guidelines for making the notebook TPU-compatible via XLA. I did the following:

import os
os.environ['PJRT_DEVICE'] = 'TPU'
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
device = xm.xla_device()

print(device)

It prints xla:0.

* The model was sent to the device with model.to(device).
* The DataLoaders were wrapped in pl.MpDeviceLoader(loader, device).
* The optimizer is stepped with xm.optimizer_step(optimizer).

As far as I know, this is how to enable multiprocessing:

def _mp_fn(index):
    # model creation
    # data preparation
    # training loop
    pass

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

I get the error:

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

I may be completely wrong about this, so I apologize in advance.
If you need to look at the code further, I can share the notebook.
When I follow the guidelines for single-core training and don't use xmp.spawn, I get 1.2 iterations/sec, which should increase significantly if I could use all the cores.

To do multiprocessing in PyTorch/XLA, all XLA-related code, such as device = xm.xla_device(), needs to run inside the _mp_fn function. You can't do it at global scope.


I followed an article that led me to rewrite the code, but it still doesn't work.
I describe what I did in detail here: use-8-cores-on-pytorch-xla-for-tpu-vm.
Can you please check it out?
Thank you so much!

Hi Adham,

First, thanks for using PyTorch/XLA. I have looked into your code. One thing you can try is to replace your ParallelLoader with MpDeviceLoader. You can find the usage examples here: xla/test_train_mp_mnist.py at dc97329679fb1baa2d4ea5e9f112e8f58dc9b1b4 · pytorch/xla · GitHub

I'd be happy to diagnose further if the above suggestion doesn't work. BTW, can you also provide code that I can easily run locally? Your sample code doesn't run.


@alanwaketan Thank you for responding. I did that and now here’s what’s happening:
When using nprocs=8 I get an exception:

A process in the process pool was terminated abruptly while the future was running or pending

I looked it up and found this GitHub issue, which suggests setting it to 1 for debugging. That WORKED, but it only runs at 1.44 it/s.
It would be great if you could help me use all 8 workers.

Here's my notebook, as you requested: Colab notebook
The code is in the cell below Main. You have edit access.

Just want to confirm: are you using a v2-8, not a v3-8?

P.S. PJRT doesn't work on Colab. We recommend using Kaggle instead. Here are examples of how to use Kaggle with PyTorch/XLA: xla/contrib/kaggle at master · pytorch/xla · GitHub.

Hi Jiewen,

I just saw this. Actually, I have been having other issues (not using PJRT) on Colab TPU. I opened an issue; would you be able to take a look?

It would be nice to be able to use Colab, since saving checkpoints to my Google Drive is easier and more reliable. (Kaggle has sometimes evaporated my checkpoints even after selecting ‘Persistence → Variables and Files’.)

Best,

Henry