Can torch.multiprocessing and torch.distributed be used within forward()?

I am trying to implement a parallel evaluation of a function on different sections of my data within the forward() of my model. I am not even sure if this is possible.
I saw that there is a torch.multiprocessing.Pool that I can use to map a function over a list of tensors, but when I use it within my forward method it complains, apparently because pool objects cannot be used within a class:
NotImplementedError: pool objects cannot be passed between processes or pickled.
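For illustration, the kind of setup I mean looks roughly like this (a simplified sketch, not my actual model); I suspect the bound method passed to map drags self, and with it the stored pool, into the pickling step:

import torch
import torch.nn as nn
import torch.multiprocessing as mp

class ParallelModule(nn.Module):  # simplified sketch, not my real model
    def __init__(self):
        super().__init__()
        self.pool = mp.Pool(4)  # pool kept on the module

    def evaluate(self, chunk):
        return chunk.sum().unsqueeze(0)

    def forward(self, x):
        # self.evaluate is a bound method, so map() pickles self,
        # which in turn tries to pickle self.pool -> NotImplementedError
        return torch.cat(self.pool.map(self.evaluate, list(x)))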

Here is, more or less, what I would like to try:

def forward(self, x):
    x = nf.unfold(x)  # unfold the input (e.g. an image) to get patches
    x = evaluate_function_in_parallel(x)  # parallelize this evaluation, e.g. x = pool.map(function, x)
    x = torch.cat(x)
    return x

I have only seen examples of distributed training with torch.multiprocessing and torch.distributed, but no examples of distributing work within the forward function. Is it even possible? If so, are there any examples available?

Any comment on this would be really helpful. Thanks.

@Juans Are you trying to use multiprocessing to parallelize processing in the forward pass on a single node, or are you trying to do this in a distributed setting?

Is there any more info you can provide about the function you are trying to parallelize in the forward pass? A search indicated that this error is thrown when you attempt to pickle the pool object itself (for example, if the function you are trying to parallelize ends up pickling the pool).

I am looking for some examples of this behavior, but the recommended approach would be to parallelize the entire training iteration (the full forward and backward pass on a single batch) using DDP, rather than just one part of the forward pass.
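For reference, a minimal single-node DDP sketch of that pattern (illustrative only: a toy model, random data, and the gloo backend on CPU):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(nn.Linear(10, 5))            # each process holds a replica
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    for _ in range(5):
        inputs = torch.randn(20, 10)         # in practice each rank reads its own data shard
        labels = torch.randn(20, 5)
        optimizer.zero_grad()
        loss_fn(model(inputs), labels).backward()  # DDP all-reduces the gradients here
        optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)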

Hi @osalpekar, thank you for your response. I am looking to parallelize processing in the forward pass using workers within a single node.
My bottleneck is not the batch processing (choosing different batch sizes has little effect on the time spent in the forward pass). Rather, most of the time is spent in a for loop inside the forward, and that loop is what I want to parallelize.

Is there any more info you can provide about the function you are trying to parallelize in the forward pass?

For example, suppose x is a tensor holding an MNIST batch and the f_i are arbitrary functions with distinct learnable parameters. Each f_i takes a tensor of size (N, M) and returns one of size (N,), and in the forward we have:

x = torch.unfold(x, kernel_size=2, stride=2)  # for MNIST this gives size [N, 196, 4]
x = torch.cat([f_i(x[..., i]) for i in range(x.shape[-1])], dim=1)  # this is very slow!

Here, whether I process 1 image or many more in a batch makes little difference. Every function f_i processes the whole batch, but only a part of the image. You could, I guess, think of the f_i's as local receptive fields whose activations we can compute in parallel. Afterwards, x passes on to other modules in the model.
The closest I have seen to my question is this and this. However, the former seems to be parallelization over the data, while the latter, which is more similar to what I intend, has no solution.
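To make the shapes concrete, here is a simplified, self-contained sketch of the pattern (not my actual model; each f_i is just a small nn.Linear applied to one 4-pixel patch and returns one value per sample, so the outputs can be concatenated):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchModel(nn.Module):  # simplified sketch
    def __init__(self, num_patches=196, patch_dim=4):
        super().__init__()
        # one small learnable function f_i per patch position
        self.fs = nn.ModuleList(nn.Linear(patch_dim, 1) for _ in range(num_patches))

    def forward(self, x):                          # x: [N, 1, 28, 28] (MNIST)
        x = F.unfold(x, kernel_size=2, stride=2)   # -> [N, 4, 196]
        x = x.transpose(1, 2)                      # -> [N, 196, 4]
        # this loop over the patches is the part I would like to parallelize
        out = [f(x[:, i, :]) for i, f in enumerate(self.fs)]  # each: [N, 1]
        return torch.cat(out, dim=1)               # -> [N, 196]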

A search indicated that this error is thrown when you attempt to pickle the pool object itself (for example, if the function you are trying to parallelize ends up pickling the pool).

I have tried this, and it does seem to avoid pickling the pool object, but the program just stops responding at some point, so I guess this approach is not viable.

Sorry for the long answer. Any comment would be much appreciated!

@Juans Thanks for clarifying your use case.

It sounds like multiprocessing.Pool.map (or something similar) would let you map a function like the f_i in your example over chunks of the input tensor using a process pool you define. The chunk size can be tuned with the chunksize argument, or with the related imap method. Will this serve your purpose?
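For instance, something along these lines (a rough standalone sketch with a hypothetical process_chunk function; chunksize just controls how many items are shipped to each worker at a time):

import torch
import torch.multiprocessing as mp

def process_chunk(chunk):
    # stand-in for one of your f_i-style computations
    return chunk.sum(dim=-1)

if __name__ == "__main__":
    x = torch.randn(256, 196, 4)
    with mp.Pool(processes=4) as pool:
        # map with an explicit chunksize, or the lazier imap variant
        results = pool.map(process_chunk, list(x), chunksize=32)
        # results = list(pool.imap(process_chunk, list(x), chunksize=32))
    out = torch.stack(results)  # [256, 196]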

I tried the following (very simplistic) example using Pool.map in the model’s forward pass and it seems to work:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

class DemoMultiProcModel(nn.Module):
    def __init__(self):
        super(DemoMultiProcModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)
        
    def forward(self, x):    
        with mp.Pool(5) as p:
            self.result = p.map(sum, x)
        return self.net2(self.relu(self.net1(x)))


def train():
    print(f"Running Model with parallelized forward pass.")
    model = DemoMultiProcModel()

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    
    for batch_idx in range(5):
        # Generate random inputs/labels
        inputs = torch.randn(20, 10)
        labels = torch.randn(20, 5)
        
        # Train
        optimizer.zero_grad()
        outputs = model(inputs)
        loss_fn(outputs, labels).backward()
        optimizer.step()
        print(f"Batch {batch_idx} done.")

@osalpekar Thanks a lot for your suggestion.
I had considered torch.multiprocessing.Pool.map since, as you correctly pointed out, it seemed to be the solution.
Unfortunately, I see two issues with it (please find the adapted code below):
First: somehow pool.map takes quite some time (about 100x more than a for-loop in this case, roughly 0.3 s, which is around 3x longer than my original problem takes). Do you know of any issues with this method? It seems strange when compared with the built-in map function. Could it be some sort of overhead? It seems quite high.
Second: your suggestion does indeed work with some small changes, and it does not raise any errors or freeze. However, it does not carry over to the case where the f_i have registered parameters. There it throws the following error:

multiprocessing.pool.MaybeEncodingError: Error sending result
Reason: 'RuntimeError('Cowardly refusing to serialize non-leaf tensor which 
requires_grad, since autograd does not support crossing process boundaries.  
If you just want to transfer the data, call detach() on the tensor before 
serializing (e.g., putting it on the queue).')'

So it is an issue with autograd. I have already read about this error here, and since distributed autograd is relatively new, I'd need to dive into the docs. Any pointers are welcome!
I tried this with versions 1.6.0 and 1.7.0.

Adapted code

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import time

class DemoMultiProcModel(nn.Module):
    def __init__(self):
        super(DemoMultiProcModel, self).__init__()
        self.net1 = nn.Linear(196, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

        self.kernel = nn.Linear(4, 1) #let's try to make a feature map

    def func_to_apply(self, data):
        return self.kernel(data)
        # return torch.sum(data, 1)[None]

    def forward(self, x):
        start = time.time()
        with mp.Pool(5) as p:
            result = p.map(self.func_to_apply, x)
        # result = []
        # for x_i in x:
        #     result.append(self.func_to_apply(x_i))
        print(f'Time it takes whole op: {time.time()-start} s')
        x = torch.cat(result, dim=0)
        return self.net2(self.relu(self.net1(x)))


def train():
    print(f"Running Model with parallelized forward pass.")
    model = DemoMultiProcModel()

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    for batch_idx in range(5):
        # Generate random inputs/labels
        inputs = torch.randn(256, 196, 4)
        labels = torch.randn(256, 5)

        # Train
        optimizer.zero_grad()
        outputs = model(inputs)
        loss_fn(outputs, labels).backward()
        optimizer.step()
        print(f"Batch {batch_idx} done.")


if __name__ == '__main__':
    train()
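
Regarding the first point above, the overhead shows up even outside the model. A minimal standalone sketch of the kind of comparison I mean (with a simple stand-in for func_to_apply):

import time
import torch
import torch.multiprocessing as mp

def f(chunk):
    return chunk.sum(dim=-1, keepdim=True)  # stand-in for func_to_apply

if __name__ == '__main__':
    x = torch.randn(256, 196, 4)
    chunks = list(x)

    start = time.time()
    seq = [f(c) for c in chunks]
    print(f"for-loop: {time.time() - start:.4f} s")

    with mp.Pool(5) as p:
        start = time.time()
        par = p.map(f, chunks)
        print(f"Pool.map: {time.time() - start:.4f} s")  # pays serialization/IPC overhead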

I see, I can confirm reproducing this error case. It seems like autograd cannot handle computing gradients on tensors created by operations that involve cross-process communication, so it simply refuses to serialize them for IPC in the forward pass itself.

@pritamdamania87 Is our assessment here correct? Is this use case supported with Distributed Autograd or some other alternative? I'm guessing one way of doing this would be to send RPCs in the forward pass to other processes to perform func_to_apply on chunks of the data and then collect the results (with distributed autograd doing the reverse in the backward pass), but I'm not sure if this is feasible or the best approach.
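Very roughly, what I have in mind is something like this (a sketch only; it assumes rpc.init_rpc has already been called on every process, and apply_f is a stand-in for the real per-chunk computation):

import torch
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd

def apply_f(weight, chunk):
    # runs on a remote worker; distributed autograd records the send/recv
    # so gradients can flow back to the caller in the backward pass
    return chunk @ weight

def parallel_forward(weights, chunks, workers):
    # fan the chunks out to the workers and gather the results
    futs = [rpc.rpc_async(w, apply_f, args=(wt, c))
            for w, wt, c in zip(workers, weights, chunks)]
    return torch.cat([f.wait() for f in futs], dim=1)

# training step on the caller process (sketch):
# with dist_autograd.context() as context_id:
#     out = parallel_forward(weights, chunks, ["worker1", "worker2"])
#     loss = loss_fn(out, labels)
#     dist_autograd.backward(context_id, [loss])
#     grads = dist_autograd.get_gradients(context_id)  # or step a DistributedOptimizer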

@osalpekar I can try the RPC approach and report back here.
I guess this use case is not so common (although, as I mentioned above, I have already seen some similar ones in the forum), so I hope it works.
Any further comments are of course welcome!

Just spoke about your use case with a few other folks, and the RPC/Distributed Autograd-based mechanism I described above should work.

Alternatively, if you can TorchScript your model (see the docs here), you may be able to use torch.jit.fork to get this multiprocessing-style parallelism. There are a number of other performance benefits to using TorchScript as well, such as bypassing the Python Global Interpreter Lock.
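A rough sketch of what that could look like (illustrative only, reusing the per-patch setup from above; fork launches each f_i asynchronously and wait collects the results, and under TorchScript these calls can actually run in parallel):

import torch
import torch.nn as nn
from typing import List

class ForkedPatchModel(nn.Module):  # illustrative sketch only
    def __init__(self, num_patches=196, patch_dim=4):
        super().__init__()
        self.fs = nn.ModuleList(nn.Linear(patch_dim, 1) for _ in range(num_patches))

    def forward(self, x):  # x: [N, num_patches, patch_dim]
        futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
        i = 0
        for f in self.fs:
            # launch each f_i asynchronously on its patch slice
            futures.append(torch.jit.fork(f, x[:, i, :]))
            i += 1
        outs = [torch.jit.wait(fut) for fut in futures]
        return torch.cat(outs, dim=1)  # [N, num_patches]

model = torch.jit.script(ForkedPatchModel())
print(model(torch.randn(8, 196, 4)).shape)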


This is great news! Thanks a lot for your help.
I'll have a look into it and hopefully come up with a solution to post here. Maybe it will be useful to others as well.