How do I map Joblib's Parallel function to PyTorch's DistributedDataParallel

I have the code below, which uses Joblib's Parallel, and I want to implement the same thing in PyTorch and run it on GPUs. I have been reading through PyTorch's DistributedDataParallel documentation, but I can't seem to figure this out.

import numpy as np
import torch
from joblib import Parallel, delayed
from torch.nn.parallel import DistributedDataParallel as DDP

X = np.array([[1, 3, 2, 3], [2, 3, 5, 6], [1, 2, 3, 4]])
X = torch.DoubleTensor(X).cuda()

def X_power_func(j):
    X_power = X**j
    return X_power

results = Parallel(n_jobs = 4)(delayed(X_power_func)(j) for j in range(8))   # how do I map this to 
                                                                             # PyTorch's
                                                                             # DistributedDataParallel

Any help would really be appreciated. Many thanks in advance!

Use torch.multiprocessing.Pool.
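
A minimal sketch of what that could look like for your example, assuming a single GPU (the 'spawn' start method and the final .clone() copy are there because CUDA tensors are shared between processes via IPC handles):

import torch
import torch.multiprocessing as mp

def X_power_func(X, j):
    # runs in a worker process; X arrives as an IPC handle to the GPU data,
    # so no device-to-host copy is made when it is passed in
    return X ** j

if __name__ == "__main__":
    # sharing CUDA tensors between processes requires the 'spawn' start method
    mp.set_start_method("spawn", force=True)

    X = torch.DoubleTensor([[1, 3, 2, 3], [2, 3, 5, 6], [1, 2, 3, 4]]).cuda()

    with mp.Pool(processes=4) as pool:
        results = pool.starmap(X_power_func, [(X, j) for j in range(8)])
        # tensors returned by workers are backed by the workers' GPU memory,
        # so copy them while the pool (and its workers) is still alive
        results = [r.clone() for r in results]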

Thanks @iffiX. Do you know in which situations we would use torch.multiprocessing versus DistributedDataParallel?

DistributedDataParallel is designed to let each process run the model's forward and backward passes independently (asynchronously), while internally it synchronously performs gradient reduction and parameter updates.
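
For concreteness, a bare-bones DDP sketch might look roughly like this (the Linear model, random inputs, and port number are just placeholders): each spawned process runs forward and backward on its own GPU, and gradient all-reduction happens behind the scenes during backward().

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    # one process per GPU, each holding its own replica of the model
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = torch.nn.Linear(4, 1).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # forward/backward run independently in every process; gradients are
    # all-reduced across processes during backward()
    inputs = torch.randn(8, 4, device=rank)
    loss = ddp_model(inputs).sum()
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)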

torch.multiprocessing is a simple derivative of the vanilla multiprocessing module: it only replaces the default queue implementation used in the vanilla module and implements an efficient way to pass around CUDA tensors (the data remains on the GPU; only a pointer to the data is passed to the subprocess pool workers).

Pool is designed for carrying out general unit tasks with a group of homogeneous workers that hold no context, such as your:

def X_power_func(j):
    X_power = X**j
    return X_power

In that sense, Pool is essentially the same as Joblib's Parallel.

Ok, many thanks @iffiX for the detailed answer.

So essentially: use DistributedDataParallel for neural-network work that you want to parallelize (anything involving forward and backward passes), and use torch.multiprocessing for non-neural-network work that you want to parallelize.

That's correct! 🙂
