Hi, does anyone know how to get the predicted labels from all 8 XLA cores and concatenate them together?
Say I have a model:
outputs = model(ids, mask, token_type_ids)
_, pred_label = torch.max(outputs.data, dim = 1)
If I do
all_predictions_np = pred_label.cpu().detach().numpy().tolist()
Apparently this only sends the result from TPU core 0 to the CPU. How can I get the predictions from all 8 cores and concatenate them into a single list? Is xm.all_gather() the right tool in this case?
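One possible approach, sketched under assumptions: each of the 8 processes launched by `xmp.spawn` holds only its own `pred_label`, so the lists have to be exchanged between processes. `xm.mesh_reduce(tag, data, reduce_fn)` in `torch_xla.core.xla_model` gathers a Python object from every core and hands `reduce_fn` a list with one copy per core, which makes concatenation a plain flatten. (For tensors, `xm.all_gather(pred_label, dim=0)` concatenates along dim 0 instead, but it expects the same shape on every core.) The block below is only the reducer, written in plain Python so it runs without a TPU; the name `concat_core_predictions` and the tag `"preds"` are hypothetical, and the `mesh_reduce` call appears only as a comment.

```python
from itertools import chain
from typing import List, Sequence


def concat_core_predictions(per_core_preds: Sequence[Sequence[int]]) -> List[int]:
    """Flatten one prediction list per core into a single list.

    Hypothetical helper meant to be passed as reduce_fn to xm.mesh_reduce,
    which (as we understand it) calls reduce_fn with a list holding the
    `data` value contributed by each core.
    """
    return list(chain.from_iterable(per_core_preds))


# Assumed usage inside the per-core function (not executed here):
#   local_preds = pred_label.cpu().tolist()   # this core's predictions
#   all_preds = xm.mesh_reduce("preds", local_preds, concat_core_predictions)
# Every core must reach the mesh_reduce call, with the same tag.

if __name__ == "__main__":
    # Stand-in data: what three cores might each hold after evaluation.
    fake_per_core = [[0, 1], [1, 1], [2, 0]]
    print(concat_core_predictions(fake_per_core))  # [0, 1, 1, 1, 2, 0]
```

If the data loader uses `DistributedSampler`, each core receives an interleaved slice of the dataset, and the sampler may repeat a few samples to even out the split. The concatenated list is therefore not necessarily in dataset order and may contain duplicates. Gathering the sample indices alongside the predictions is one way to realign them.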