tensor.cpu() takes a lot of time with PyTorch XLA multicore

I am using PyTorch XLA with a TPU on Google Colab. As expected, computations are indeed faster; however, when I transfer tensors back to the CPU to save them, the transfer itself takes a lot of time for big tensors. Here is the function I run in parallel:

def _mp_fn(index, batch_size, model, typee, directory_join):
    device = xm.xla_device()
    xm.master_print("Here")

    # "Download only once" barrier: non-master ordinals wait here while
    # the master builds the dataset first.
    if not xm.is_master_ordinal():
        xm.rendezvous("download_only_once")

    dataset = sherBERTEmbedding(f_df, typee)

    if xm.is_master_ordinal():
        xm.rendezvous("download_only_once")
    xm.master_print("There")

    sampler = DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # drop_last=True,
    )

    model = model.to(device)
    model.eval()

    para_loader = pl.ParallelLoader(dataloader, [device])

    # Preallocate device buffers large enough for every batch this
    # replica will see (sequence length 153, embedding dim 768).
    num_batches = len(dataloader)
    tensor_embeddings = torch.zeros((batch_size * num_batches, 153, 768), dtype=torch.float).to(device)
    tensor_ids = torch.zeros((batch_size * num_batches, 153), dtype=torch.long).to(device)
    tensor_index = torch.zeros(batch_size * num_batches, dtype=torch.long).to(device)

    prev_ind = 0  # number of rows actually written into the buffers

    with torch.no_grad():
        for t, (x, x_pad_mask, index) in tqdm(
            enumerate(para_loader.per_device_loader(device)), total=num_batches
        ):
            x = x.to(device=device, dtype=torch.long)
            x_pad_mask = x_pad_mask.to(device=device, dtype=torch.long)
            result = get_embeddings_best_token(x, x_pad_mask, model, index_best_tokens)
            ids = get_ids_best_tokens(x, index_best_tokens)

            tensor_embeddings[prev_ind: prev_ind + result.size(0), :, :] = result
            tensor_ids[prev_ind: prev_ind + ids.size(0), :] = ids
            tensor_index[prev_ind: prev_ind + index.size(0)] = index.squeeze(1)
            prev_ind += result.size(0)


        # Saving: slicing with prev_ind drops the unused rows at the end
        # of the preallocated buffers (the last batch can be smaller).
        directory_embeddings = os.path.join(directory_join, typee, "embeddings")
        directory_ids = os.path.join(directory_join, typee, "ids")
        directory_index = os.path.join(directory_join, typee, "index")

        tosave = tensor_index[:prev_ind].cpu()
        torch.save(tosave, os.path.join(directory_index, str(xm.get_ordinal()) + ".p"))
        print("Saving 1")

        tosave = tensor_ids[:prev_ind, :].squeeze().cpu()
        torch.save(tosave, os.path.join(directory_ids, str(xm.get_ordinal()) + ".p"))
        print("Saving 2")

        # This device-to-host transfer of the big embeddings tensor is the slow step.
        tosave = tensor_embeddings[:prev_ind, :, :].view(-1, 768).cpu()
        print("Saving 3")
        torch.save(tosave, os.path.join(directory_embeddings, str(xm.get_ordinal()) + ".p"))

def get_embeddings_best_token(root, mask, model, index_best_tokens):
    # Run the model and keep only the contiguous range of token positions
    # of interest.
    res = model((root, mask))
    return res[:, index_best_tokens[0]:index_best_tokens[-1] + 1, :]

def get_ids_best_tokens(x, index_best_tokens):
    # Same contiguous range of token positions, but on the input ids.
    return x[:, index_best_tokens[0]:index_best_tokens[-1] + 1]

And I run it in parallel with this:

t1 = time.time()
get_embeddings(8, model_embed, "myType", JOIN_DIR)
t2 = time.time()
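
For completeness, get_embeddings is essentially a thin wrapper around xmp.spawn; a minimal sketch of it, assuming Colab's 8 TPU cores and the fork start method:

import torch_xla.distributed.xla_multiprocessing as xmp

def get_embeddings(batch_size, model, typee, directory_join):
    # Spawn _mp_fn on every TPU core; xmp.spawn passes the process
    # index as the first argument and the args tuple after it.
    xmp.spawn(_mp_fn, args=(batch_size, model, typee, directory_join),
              nprocs=8, start_method="fork")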

I know that going back to the CPU between computations on the TPU is a bottleneck, but here I am only trying to save the tensors after the computations are done. I know that a dedicated function exists to save tensors: torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False). However, as the docs explain, it basically converts the XLA tensors to CPU tensors and then saves them, and when I use this function it takes around the same amount of time.
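
For reference, a minimal sketch of the xm.save variant (per-ordinal filenames as above, with master_only=False so every core writes its own shard):

import torch_xla.core.xla_model as xm

# xm.save moves the XLA tensor to the CPU internally before serializing,
# so it pays the same device-to-host transfer cost as tensor.cpu().
path = os.path.join(directory_embeddings, str(xm.get_ordinal()) + ".p")
xm.save(tensor_embeddings[:prev_ind].view(-1, 768), path, master_only=False)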

Hi, just want to let you know that for all PyTorch/XLA (or PyTorch on TPU) questions, please open an issue at https://github.com/pytorch/xla.
Thanks!

Thanks! I will open an issue, and link it here if it appears to be a problem with my code!