How to concatenate the predicted labels from all TPU cores in XLA?

Hi, does anyone know how to collect the predicted labels from all 8 TPU cores in XLA and concatenate them?

Say I have a model:
outputs = model(ids, mask, token_type_ids)
_, pred_label = torch.max(outputs, dim=1)  # index of the highest-scoring class for each example

If I do
all_predictions_np = pred_label.cpu().detach().numpy().tolist()

Apparently this only copies the result from TPU core 0 to the CPU. How can I collect the predictions from all 8 cores and concatenate them into a single list? I am not sure whether xm.all_gather() is the right tool here.
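A minimal sketch of two possible approaches, assuming the usual PyTorch/XLA setup where xmp.spawn starts one process per core, so each process only holds its own core's pred_label. xm.all_gather() concatenates the tensors from every core on the TPU and then needs a single copy to the CPU, but it assumes every core's tensor has the same shape (for example, drop_last=True in the DataLoader). xm.mesh_reduce() instead combines the already-copied Python lists, so it should also tolerate a smaller final batch on some cores. The helper names concat_per_core, gather_labels_all_gather and gather_labels_mesh_reduce are hypothetical, and torch_xla is imported inside the functions only so the module can be loaded without a TPU; check the exact behaviour against your torch_xla version.

```python
from typing import List, Sequence


def concat_per_core(per_core_preds: Sequence[List[int]]) -> List[int]:
    """Flatten a sequence of per-core label lists into one flat list."""
    return [label for core_preds in per_core_preds for label in core_preds]


def gather_labels_all_gather(pred_label) -> List[int]:
    """Hypothetical helper: gather on the TPU, then copy to the CPU once.

    Assumes pred_label has the same shape on every core, because
    xm.all_gather concatenates the per-core tensors along dim 0.
    """
    import torch_xla.core.xla_model as xm  # imported here so the module loads without torch_xla

    gathered = xm.all_gather(pred_label, dim=0)  # (num_cores * batch,) on the XLA device
    return gathered.cpu().tolist()


def gather_labels_mesh_reduce(pred_label) -> List[int]:
    """Hypothetical helper: copy to the CPU first, then merge the lists.

    xm.mesh_reduce passes the list of per-process values to the reduce
    function, so per-core batch sizes do not need to match.
    """
    import torch_xla.core.xla_model as xm  # imported here so the module loads without torch_xla

    local_preds = pred_label.cpu().tolist()
    return xm.mesh_reduce("pred_labels", local_preds, concat_per_core)
```

Either helper would be called from inside the function passed to xmp.spawn, e.g. all_predictions = gather_labels_mesh_reduce(pred_label). Every process then receives the full list, so guarding any file writes with xm.is_master_ordinal() avoids saving it 8 times.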