Enable multiprocessing in PyTorch/XLA on a TPU VM

To run multiple processes in PyTorch/XLA, all XLA-related code, such as device = xm.xla_device(), must run inside the _mp_fn function. It cannot run at global scope.