Concurrent CPU execution when the GPU runs out of memory

Hi PyTorch community, I hope this is the right section to post my issue in.

I have a fairly large dataset of biological (protein) sequences, around 68M, that I have to process with Rostlab/prot_t5_xl_uniref50, a pre-trained NLP model that generates embeddings for downstream tasks. I have set up a decent workflow that recovers when my single V100 GPU goes OOM because a sequence is too long (even with batch size set to 1), by falling back to CPU and RAM for the calculation.
As of now, I use the suggested PyTorch implementation inside the for loop:

oom = False
try:
    embedding = model(input_ids=input_ids, attention_mask=attention_mask)
except RuntimeError:
    oom = True
if oom:
    embedding = model(input_ids=input_ids.cpu(), attention_mask=attention_mask.cpu())

This approach solves the OOM issues, but I realized that it has the side effect of leaving the GPU idle while the big sequence is processed on the CPU; the GPU only resumes work with the next iteration.
Is there any way to send the problematic sequence to the CPU while continuing to process the rest normally on the GPU?

Things I’ve tried:

  • DataParallel is not the right answer: ValueError: Expected a non cpu device, but got: cpu
  • Python multiprocessing expects a pickleable Dataset: AttributeError: Can't pickle local object 'main.<locals>.SequenceDataset'
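
The second error usually means the Dataset class is defined inside a function (here, inside main), so pickle cannot look it up by name. Below is a minimal sketch of the difference, assuming a plain Python class as a stand-in for the real torch Dataset; make_local_dataset and is_picklable are hypothetical helper names used only for illustration.

```python
import pickle


class SequenceDataset:
    """Defined at module level, so pickle can find it by its qualified name."""

    def __init__(self, sequences):
        self.sequences = list(sequences)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]


def make_local_dataset(sequences):
    """Reproduces the failure: the class only exists inside this function."""

    class LocalDataset(SequenceDataset):
        pass

    return LocalDataset(sequences)


def is_picklable(obj):
    """Return True if pickle.dumps succeeds on obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False
```

Moving the class definition out of main to the top level of the script is typically enough for multiprocessing to accept it; instances of the function-local class keep failing the is_picklable check.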

If this is an inference task without the need to do things like broadcasting gradients, could you “outsource” the CPU inference to a subprocess (see the subprocess module in the Python 3.10.2 documentation), e.g., via Popen, as a workaround?
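
A minimal sketch of that idea, assuming the child code is passed inline with `python -c` so no separate script file is needed, and that data is exchanged as JSON over stdin/stdout. CHILD_CODE, start_cpu_job and collect_cpu_job are hypothetical names, and the child only computes sequence lengths as a stand-in for the real CPU inference.

```python
import json
import subprocess
import sys

# Hypothetical stand-in for the CPU inference code. A real child would load
# the model and compute embeddings instead of sequence lengths.
CHILD_CODE = """
import json, sys
sequences = json.load(sys.stdin)
json.dump([len(s) for s in sequences], sys.stdout)
"""


def start_cpu_job(sequences):
    """Launch the child without blocking and send it the sequences on stdin."""
    proc = subprocess.Popen(
        [sys.executable, "-c", CHILD_CODE],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        text=True,
    )
    proc.stdin.write(json.dumps(sequences))
    proc.stdin.close()  # signals end of input to the child
    return proc


def collect_cpu_job(proc):
    """Block until the child finishes and parse its JSON output."""
    output = proc.stdout.read()
    proc.stdout.close()
    proc.wait()
    return json.loads(output)
```

The parent can call start_cpu_job when a sequence fails on the GPU, keep iterating over the remaining batches, and call collect_cpu_job at the end to merge the results back.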

Thank you for pointing out that indeed I don’t need gradients, as the model is put into evaluation mode straight away.
About your solution, is it possible to use it without having to write additional scripts that get called via the command line? I’m asking because that would require some coding on my part to generate the called script, and then to merge the CPU-processed sequences back into the main dataset.

You might want to check out the multiprocessing module, either the torch version (torch.multiprocessing, see the Multiprocessing package page in the PyTorch 1.10.1 documentation) or the native version available in Python.
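
A rough sketch of how that could look with the native module, assuming a single background worker process that takes over whenever the GPU call raises RuntimeError. MAX_GPU_LEN, embed_on_gpu, embed_on_cpu and run are hypothetical stand-ins; the real functions would call the model on the respective device.

```python
import multiprocessing as mp

MAX_GPU_LEN = 1000  # hypothetical length above which the simulated GPU goes OOM


def embed_on_gpu(sequence):
    """Stand-in for the GPU forward pass; fails on long sequences."""
    if len(sequence) > MAX_GPU_LEN:
        raise RuntimeError("CUDA out of memory (simulated)")
    return ("gpu", len(sequence))


def embed_on_cpu(sequence):
    """Stand-in for the CPU forward pass, executed in the worker process."""
    return ("cpu", len(sequence))


def run(sequences):
    """Process sequences in order, offloading OOM cases to a CPU worker."""
    results = [None] * len(sequences)
    pending = []
    with mp.Pool(processes=1) as pool:
        for i, seq in enumerate(sequences):
            try:
                results[i] = embed_on_gpu(seq)
            except RuntimeError:
                # apply_async returns immediately, so the loop moves on
                # to the next sequence while the worker handles this one.
                pending.append((i, pool.apply_async(embed_on_cpu, (seq,))))
        # Merge the CPU results back into their original positions.
        for i, async_result in pending:
            results[i] = async_result.get()
    return results
```

Calling run(["MKV", "A" * 2000, "GG"]) from under an `if __name__ == "__main__":` guard should keep the results in input order, with the long sequence handled by the worker. The functions live at module level so that they can be pickled.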

+1 to delegating CPU inference to a subprocess as mentioned by eqy. Additionally, you can look into using the RPC framework for this (see Distributed RPC Framework in the PyTorch master documentation), which allows you to send tensors between processes easily.

For your use case, here is a basic example that spawns multiple processes programmatically (without the command line) and uses RPC to delegate tensors to a CPU or GPU model based on the contents of the tensor.

import os
import random
import torch
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
import torch.nn as nn
import time

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = nn.Linear(1, 1)

    def forward(self, x):
        return self.layer(x)

model = Model()
shutdown = 0

def set_shutdown():
    global shutdown
    shutdown = 1

def process_data(input, use_cuda):
    if use_cuda:
        # move both the model and the input to the GPU
        model.to("cuda:0")
        input = input.to("cuda:0")

    # return tensor output as cpu
    return model(input).cpu()

def worker_0():
    global shutdown
    rpc.init_rpc("cpu_worker", rank=0, world_size=3)

    # spins there until shutdown is set
    while not shutdown:
        time.sleep(1)
    rpc.shutdown()

def worker_1():
    global shutdown
    rpc.init_rpc("cuda_worker", rank=1, world_size=3)

    # spins there until shutdown is set
    while not shutdown:
        time.sleep(1)
    rpc.shutdown()

def worker_2():
    rpc.init_rpc("controller", rank=2, world_size=3)
    data = torch.rand(10, 1)
    futures = []
    for rand_tensor in data:
        # if random number is > 0.5, then use cpu to process, else use gpu
        if rand_tensor > 0.5:
            fut = rpc.rpc_async("cpu_worker", process_data, args=(rand_tensor, False))
        else:
            fut = rpc.rpc_async("cuda_worker", process_data, args=(rand_tensor, True))
        futures.append(fut)
    # wait for every delegated call and print its result
    for fut in futures:
        result = fut.wait()
        print(result)
    # tell both workers to leave their wait loops
    rpc.rpc_sync("cpu_worker", set_shutdown)
    rpc.rpc_sync("cuda_worker", set_shutdown)
    rpc.shutdown()

if __name__ == "__main__":
    p0 = mp.Process(target=worker_0)
    p1 = mp.Process(target=worker_1)
    p2 = mp.Process(target=worker_2)
    processes = [p0, p1, p2]
    for p in processes:
        p.start()
    for p in processes:
        p.join()