I’m fairly new to this and have little to no experience.
I have a notebook running PyTorch that I wanted to run on a Google Cloud TPU VM.
Machine specs:
* Ubuntu
* TPU v2-8
* pt-2.0
I should have 8 cores. Correct me if I’m wrong.
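In case it's useful, this is the quick check I'd use to count the devices the process can see (just a sketch; I'm assuming `xm.get_xla_supported_devices()` behaves the same way under PJRT):

```python
import os
os.environ['PJRT_DEVICE'] = 'TPU'  # set before doing any XLA device work

import torch_xla.core.xla_model as xm

# Assumption: this lists the XLA devices visible to this process; I'm not
# certain it reports all 8 cores under PJRT on a v2-8.
devices = xm.get_xla_supported_devices()
print(len(devices), devices)
```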
So, I followed the guidelines for making the notebook TPU-compatible via XLA. I did the following:
```python
import os
os.environ['PJRT_DEVICE'] = 'TPU'

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

device = xm.xla_device()
print(device)
```
It prints `xla:0`.
* Models were sent to the device via `model.to(device)`.
* Dataloaders were wrapped in `pl.MpDeviceLoader(loader, device)`.
* The optimizer is stepped via `xm.optimizer_step(optimizer)`.
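Concretely, the per-step pattern looks roughly like this (a minimal sketch; `model`, `loader`, `optimizer`, and `loss_fn` stand in for the real objects in my notebook):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
model = model.to(device)                           # move the model onto the XLA device
train_loader = pl.MpDeviceLoader(loader, device)   # streams batches onto the device

for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    xm.optimizer_step(optimizer)                   # steps the optimizer and syncs replicas
```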
As far as I know, this is how to enable multiprocessing:
```python
def _mp_fn(index):
    # model creation
    # data preparation
    # training loop
    ...

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
```
I get the error:
```
BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
```
I could be getting this completely wrong, so apologies in advance.
If you need a closer look at the code, I can share the notebook.
When I follow the guidelines for single-core processing and don't use `xmp.spawn`, I get 1.2 iterations/sec, which should increase significantly if all 8 cores were used.
To do multi-processing in PyTorch/XLA, all XLA-related code such as `device = xm.xla_device()` needs to go inside the `_mp_fn` function. You can't do it at global scope.
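Roughly like this (just a sketch; `build_model`, `build_loader`, `build_optimizer`, `loss_fn`, and `num_epochs` are placeholders for your own code):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # All XLA-related setup lives inside the per-process function.
    device = xm.xla_device()
    model = build_model().to(device)
    loader = pl.MpDeviceLoader(build_loader(), device)
    optimizer = build_optimizer(model)

    for epoch in range(num_epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            xm.optimizer_step(optimizer)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
```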
I followed an article, which led me to write different code, but it still doesn't work.
I describe what I did in detail here: use-8-cores-on-pytorch-xla-for-tpu-vm.
Can you please check it out?
Thank you so much
I'd love to diagnose further if the above suggestion doesn't work. BTW, can you also provide code that I can easily run locally? Your sample code doesn't run.
@alanwaketan Thank you for responding. I did that and now here’s what’s happening:
When using `nprocs=8` I get an exception:

```
A process in the process pool was terminated abruptly while the future was running or pending
```
I looked it up and found this GitHub issue, which says to set it to 1 for debugging. That WORKED, but it's only 1.44 it/s.
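In other words (same `_mp_fn` as above):

```python
xmp.spawn(_mp_fn, args=(), nprocs=1)    # works, but only ~1.44 it/s
# xmp.spawn(_mp_fn, args=(), nprocs=8)  # raises BrokenProcessPool on my v2-8 VM
```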
It would be great if you could help me use all 8 workers.
Here's my notebook, as per your request: Colab notebook.
The code is in the cell below "Main". You have edit access.
I just saw this - actually, I have been having other issues (not using PJRT) on Colab TPU. I opened an issue, would you be able to take a look?
It would be nice to be able to use Colab since saving checkpoints to my google drive is easier and more reliable. (Kaggle has sometimes evaporated my checkpoints even after selecting ‘Persistence → Variables and Files’.)