Simple parallel GPU inference

I have a trained model and 4 GPUs available. I have a dataset that I want to split into 4 parts (and process with the same batch size on each GPU), run the parts independently of each other, and essentially add up the results I get from each GPU. I did not get much wiser from e.g. How do I run Inference in parallel?. Please note that I am only doing inference with my model; no gradient computations etc. are required.

A minimal example of what I’m trying to do is this:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import os
from torch.nn.parallel import DistributedDataParallel as DDP
torch.random.manual_seed(123)

input_dim, out_dim = 10, 1
net = nn.Linear(input_dim, out_dim) # I load my model from a saved state_dict

m = torch.cuda.device_count()
n = 5 # number of data pts sent to each GPU
x = torch.rand((n,input_dim,m)) # full data set

# without parallel processing:
s0 = 0 
for i in range(m):
    s0 += net(x[:,:,i]).sum()
print('s0', s0)

# with parallel processing:    
def example(rank, world_size):
    print('rank', rank)
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = net.to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # forward pass
    x_local = x[:,:,rank].to(rank)
    outputs = ddp_model(x_local)
    print(outputs.sum()) # these add up to the desired value s0 
    # but how do I return these values from each GPU process and add them? 

def main():
    world_size = torch.cuda.device_count()
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()

I’ve seen examples where people do all of the data/model preparation inside the example function, but I don’t know why.
Also, I feel like there should be a cleaner way of loading the data using a DataLoader that would work more smoothly with DDP, so pointers for this are also very welcome. I also posted this question on Stack Exchange but got no replies.

Your code example looks right. If you need to add the outputs across GPUs, you need some sort of communication, e.g. all_reduce to sum across the different ranks: Distributed communication package - torch.distributed — PyTorch main documentation
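A minimal sketch of how that could look inside your example function (reusing the names net, x and DDP from your snippet; dist.all_reduce sums the local tensor in place across all ranks, so every rank ends up holding the full sum; with the gloo backend a CPU tensor is the safe choice to reduce):

def example(rank, world_size):
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = net.to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    x_local = x[:, :, rank].to(rank)
    with torch.no_grad():                    # inference only, no gradients needed
        partial = ddp_model(x_local).sum()   # partial sum computed on this GPU
    partial = partial.cpu()                  # reduce a CPU tensor with the gloo backend
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # in-place sum over all ranks
    if rank == 0:
        print('total', partial)              # should match s0 from the serial loop
    dist.destroy_process_group()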

For loading the data you can use the regular PyTorch DataLoader, but pass a DistributedSampler as the sampler: torch.utils.data — PyTorch 2.3 documentation
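A minimal sketch of that, assuming the full data set is a 2D tensor x of shape (N, input_dim) wrapped in a TensorDataset (make_loader is just an illustrative helper name); the DistributedSampler splits the indices so each rank only iterates over its own shard:

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def make_loader(x, rank, world_size, batch_size=5):
    dataset = TensorDataset(x)  # wrap the full data set
    # each rank gets a disjoint subset of the indices
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# inside the per-rank function:
# loader = make_loader(x, rank, world_size)
# for (batch,) in loader:
#     out = ddp_model(batch.to(rank))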

Edited post with Pipe solution:

Thanks for the pointers. I have found an acceptable solution using mp.Pipe instead, which seems to be needed since I want to get the results from the individual GPUs back into the main scope (do correct me if I’m wrong). Also, there is a mysterious issue (to me) when trying to send tensors through parent_conn, please see below.

import os, sys 
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def get_model_and_data(N=12, num_devices=2):
    torch.random.manual_seed(0)
    torch.set_default_tensor_type('torch.DoubleTensor') # for more exact comparisons
    net = nn.Linear(1, 1)
    x = torch.linspace(0, 1, N).unsqueeze(1)
    if (N % num_devices) != 0:
        sys.exit(f'Full tensor length {N} not evenly divisible over {num_devices} GPUs.')
    return net, x.reshape(num_devices, -1, 1)

def inference(rank, world_size, net, x, conn):
    torch.set_printoptions(8)
    torch.set_default_tensor_type('torch.DoubleTensor')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    xloc = x[rank].to(rank)
    model = net.to(rank)
    model = DDP(model, device_ids=[rank])
    output = model(xloc).sum().detach() # some operation I want to do on each GPU (no grads required)
    print('Rank ', rank, 'partial sum:', output) # partial sum from each process 
    r = float(output)
    conn.send((rank, r)) # this works OK 
    #conn.send(output) # this yields warnings when sending and an error when receiving
          
def main():
    torch.set_printoptions(8)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    net, x = get_model_and_data()
    s0 = net(x).sum() # compute without parallel processing
    world_size = x.shape[0]
    
    parent_conn, child_conn = mp.Pipe()
    mp.spawn(inference,
        args=(world_size, net, x, child_conn),
        nprocs=world_size,
        join=True)
    
    res = []
    while parent_conn.poll():
        c = parent_conn.recv()
        print(f'receiving {c}')
        res.append(c)
    print(res)
        
        
if __name__=="__main__":
    main()

However, this only works if I cast output to e.g. a float. If I do conn.send(output) then I get a warning during sending and eventually an error when receiving. The same behavior results if output is replaced with simply e.g. torch.ones(1) or torch.ones(1).to(rank). So if anyone knows how to properly send tensors this way, I would like to know.

Rank  0 partial sum: tensor(3.77580110, device='cuda:0')
Rank  1 partial sum: tensor(6.85251166, device='cuda:1')
[W CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
[W CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
receiving (0, 3.7758011049971456)
Traceback (most recent call last):...
RuntimeError: CUDA error: invalid resource handle
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
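My current guess (unverified) is that sending the CUDA tensor through the pipe shares it via CUDA IPC, and the handle becomes invalid once the spawned producer process exits before the parent reads it, which is what the CudaIPCTypes warning seems to hint at. A sketch of what I would try, assuming that reading is correct, is to send a plain CPU copy instead:

    # inside inference(), instead of conn.send(output):
    conn.send((rank, output.cpu()))  # CPU tensor, no CUDA IPC handle involved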