Evaluate multiple models on multiple GPUs

Hello guys,

I would like to evaluate my models in parallel on multiple GPUs. I don't have much experience using Python and PyTorch this way. Here is pseudocode of what I'm trying to do:

import torch
import torch.multiprocessing as mp
from mycnn import CNN
from data_parser import parser
from fitness import get_fitness  # this also runs on GPU

def run_model(outputs, model, device_id, batch):
    # move the model and data to this worker's GPU inside the child process
    device = torch.device("cuda:" + str(device_id))
    model = model.to(device)
    batch = batch.to(device)
    with torch.no_grad():  # evaluation only, no autograd graph needed
        out = model(batch)
        f = get_fitness(out)  # needs the model output, so it runs after the forward pass
    outputs[device_id] = f.cpu()  # write the result into the shared CPU tensor

if __name__ == '__main__':
    mp.set_start_method('spawn')  # required to use CUDA in child processes
    batch = parser.get_batch()
    model = CNN()
    GPU_NUM = 2
    outputs = torch.zeros(GPU_NUM, dtype=torch.double)  # collect outputs here
    outputs.share_memory_()  # children write their results into this shared tensor
    processes = []

    for dev_id in range(GPU_NUM):
        # pass the CPU model and batch; each child moves its own copy to its GPU
        p = mp.Process(target=run_model, args=(outputs, model, dev_id, batch))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    print(outputs)

Sadly this doesn’t work at all and I’m probably doing it completely wrong. I’m getting this error:

OSError: [Errno 12] Cannot allocate memory

For some reason it drains all the memory on my server, even when GPU_NUM = 1. When I run the code synchronously I get no errors. Could you please tell me the right way to do something like this, or point me to some examples that would help?

If you have multiple GPUs and want to evaluate each model on a single dedicated GPU independently, you could just push each model to a GPU via:

modelA = modelA.to('cuda:0')
modelB = modelB.to('cuda:1')
...

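and evaluate each model separately.
Since CUDA calls are asynchronous, the GPUs shouldn't block each other as long as you don't synchronize them manually. Note that reading a result back to the host, e.g. via .cpu() or .item(), also waits for that device to finish.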
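A minimal sketch of this single-process approach, assuming hypothetical stand-ins make_model and fitness for your CNN and get_fitness. It places one model per GPU, queues all the work first, and only reads the results back at the end. It falls back to the CPU when fewer GPUs than models are available, so it also runs on a machine without CUDA.

```python
import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for the CNN from the question
    return nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))


def fitness(out: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for get_fitness: any tensor reduction works here
    return out.mean()


def evaluate_on_devices(models, batch):
    # One GPU per model if enough are available, otherwise fall back to CPU
    n_gpus = torch.cuda.device_count()
    devices = [
        torch.device(f"cuda:{i}") if i < n_gpus else torch.device("cpu")
        for i in range(len(models))
    ]

    pending = []
    with torch.no_grad():
        for model, device in zip(models, devices):
            model = model.to(device).eval()
            x = batch.to(device)
            # On a GPU, the forward pass and fitness only enqueue kernels and
            # return, so the loop moves on to the next device without waiting
            pending.append(fitness(model(x)))

    # .item() copies each result to the host and waits for its device to finish
    return [f.item() for f in pending]


if __name__ == "__main__":
    torch.manual_seed(0)
    models = [make_model() for _ in range(2)]
    batch = torch.randn(32, 16)
    print(evaluate_on_devices(models, batch))
```

The synchronization happens only once per device, after everything has been queued. Calling .item() inside the loop would make the host wait for each GPU before launching work on the next one, which serializes them again.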

If the "Cannot allocate memory" error still persists, you might try allocating additional swap space on your machine (see https://github.com/pytorch/pytorch/issues/4387 for more details).