Hi PyTorch community, I hope this is the right section to post my issue in.
I have a fairly large dataset of biological (protein) sequences, around 68M of them, that I have to process with Rostlab/prot_t5_xl_uniref50, a pre-trained NLP model that generates embeddings for downstream tasks. I have set up a decent workflow that recovers when my single V100 GPU runs out of memory (OOM) on a sequence that is too long (even with batch size set to 1), by falling back to the CPU and system RAM for that computation.
As of now, I use the suggested PyTorch implementation inside the for loop:
This approach solves the OOM issues, but I realized it has a side effect: the GPU sits idle while the large sequence is being processed on the CPU, and only resumes work on the next iterations.
Is there a way to send the problematic sequence to the CPU while the GPU continues processing the rest as normal?
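One in-process way to get this overlap, offered here as an alternative sketch rather than anything proposed in the thread, is to hand each OOM sequence to a background thread and keep feeding the GPU from the main thread. Most heavy PyTorch operators release the GIL, so the CPU thread should be able to make progress concurrently. The callables `gpu_embed` and `cpu_embed` are hypothetical placeholders for wrappers around two separate copies of the model (one on `cuda:0`, one in RAM); sharing a single model and moving it with `.cuda()`/`.cpu()` from two threads would race. The OOM check by message text is an assumption that matches how PyTorch words its CUDA OOM errors.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Dict, List, Sequence, TypeVar

T = TypeVar("T")


def is_cuda_oom(exc: BaseException) -> bool:
    # PyTorch reports CUDA OOM as a RuntimeError (torch.cuda.OutOfMemoryError
    # in recent releases) whose message contains "out of memory".
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc)


def embed_all(
    sequences: Sequence[str],
    gpu_embed: Callable[[str], T],
    cpu_embed: Callable[[str], T],
    cpu_threads: int = 1,
) -> List[T]:
    """Embed every sequence on the GPU; on OOM, queue it on a CPU thread
    and move straight on to the next sequence. Results keep input order."""
    results: List[T] = [None] * len(sequences)  # type: ignore[list-item]
    pending: Dict[int, Future] = {}
    with ThreadPoolExecutor(max_workers=cpu_threads) as cpu_pool:
        for i, seq in enumerate(sequences):
            try:
                results[i] = gpu_embed(seq)
            except RuntimeError as exc:
                if not is_cuda_oom(exc):
                    raise  # unrelated errors should still surface
                # In real code, torch.cuda.empty_cache() could be called here
                # to release cached blocks left by the failed GPU attempt.
                pending[i] = cpu_pool.submit(cpu_embed, seq)
        # Wait for the CPU fallbacks and slot them back into their positions.
        for i, fut in pending.items():
            results[i] = fut.result()
    return results
```

Because results are written back by index, the CPU-processed embeddings end up in the same list as the GPU ones, with no separate merge step.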
Things I’ve tried:
- `DataParallel` is not the right answer: `ValueError: Expected a non cpu device, but got: cpu`
- Python multiprocessing expects a picklable Dataset: `AttributeError: Can't pickle local object 'main.<locals>.SequenceDataset'`
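The pickling error comes from the Dataset class being defined inside a function: `pickle` stores classes by their importable qualified name, and a name like `main.<locals>.SequenceDataset` cannot be looked up. A minimal sketch of the difference, using a plain class for illustration (the same rule applies to `torch.utils.data.Dataset` subclasses) and assuming the file is run directly as a script:

```python
import pickle


class SequenceDataset:
    """Defined at module level, so pickle can find it by its qualified name."""

    def __init__(self, sequences):
        self.sequences = list(sequences)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]


def make_local_dataset(sequences):
    # Same behaviour, but defined inside a function: its __qualname__ becomes
    # 'make_local_dataset.<locals>.LocalDataset', which pickle cannot import.
    class LocalDataset(SequenceDataset):
        pass

    return LocalDataset(sequences)


def is_picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # Older Pythons raise AttributeError here; newer ones PicklingError.
        return False
```

Moving the class definition to the top level of the module is usually enough to get past this particular error.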
Thank you for pointing out that I don’t need gradients: indeed, I only run inference, and the model is put into evaluation mode straight away.
About your solution, is it possible to use it without having to write additional scripts that get called via the command line? I’m asking because that would require some coding on my part to generate the called script, and then to merge the CPU-processed sequences back into the main dataset.
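Note that `model.eval()` by itself only switches layers such as dropout to inference behaviour; autograd keeps recording the graph unless gradient tracking is disabled explicitly. A minimal sketch, using a small `nn.Linear` as a hypothetical stand-in for the ProtT5 encoder:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the real encoder
model.eval()  # changes dropout/batch-norm behaviour only

x = torch.randn(3, 4)
y_eval_only = model(x)  # autograd still tracks this output

with torch.inference_mode():  # torch.no_grad() also works on older releases
    y = model(x)  # no graph is recorded, saving memory per sequence
```

Skipping the autograd graph also lowers peak memory, which may let a few more long sequences fit on the GPU before falling back.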
+1 to delegating CPU inference to a subprocess, as eqy mentioned. You can also look into the RPC framework for this (see the Distributed RPC Framework page in the PyTorch documentation), which makes it easy to send tensors between processes.
For your use case, here is a basic example that spawns multiple processes programmatically (no command-line scripts needed) and uses RPC to delegate each tensor to a CPU or GPU model based on the tensor's contents.
```python
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn

# Rendezvous address shared by all RPC processes (single machine).
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)

    def forward(self, x):
        return self.layer(x)


# Each process ends up with its own copy of this module-level model.
model = Model()
model.eval()


def process_data(tensor, use_cuda):
    # Executes inside whichever worker receives the RPC.
    if use_cuda:
        tensor = tensor.to("cuda:0")
        model.cuda()
    else:
        model.cpu()
    print(next(model.parameters()).device)
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        # Return the output on the CPU so it can be sent back over RPC.
        return model(tensor).cpu()


def run_worker(name, rank):
    rpc.init_rpc(name, rank=rank, world_size=3)
    # rpc.shutdown() is graceful by default: it blocks until every process
    # reaches it and all outstanding RPCs have finished, so the worker keeps
    # serving requests until then without a busy-wait loop or shutdown flag.
    rpc.shutdown()


def run_controller():
    rpc.init_rpc("controller", rank=2, world_size=3)
    data = torch.rand(10, 1)
    futures = []
    for rand_tensor in data:
        # If the random number is > 0.5, process on the CPU; otherwise on the GPU.
        if rand_tensor > 0.5:
            fut = rpc.rpc_async("cpu_worker", process_data, args=(rand_tensor, False))
        else:
            fut = rpc.rpc_async("cuda_worker", process_data, args=(rand_tensor, True))
        futures.append(fut)
    # Futures were appended in input order, so results stay aligned with data.
    for fut in futures:
        result = fut.wait()
        print(result)
    rpc.shutdown()


if __name__ == "__main__":
    processes = [
        mp.Process(target=run_worker, args=("cpu_worker", 0)),
        mp.Process(target=run_worker, args=("cuda_worker", 1)),
        mp.Process(target=run_controller),
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```