Can torch.multiprocessing and torch.distributed be used within forward()?

Hi @osalpekar, thank you for your response. I am looking to parallelize processing in the forward pass across workers within a single node.
My bottleneck is not the batch processing (choosing different batch sizes has little effect on the time spent in the forward pass). Rather, most of the time is spent in a part of the forward pass where I have a for loop, and that loop is what I want to parallelize.

You asked:
"Is there any more info you can provide about the function you are trying to parallelize in the forward pass?"

For example, suppose x is a tensor holding an MNIST batch and the f_i are arbitrary functions, each with its own learnable parameters. Each f_i maps a tensor of size (N, M) to (N,), and in the forward pass we have:

x = torch.nn.functional.unfold(x, kernel_size=2, stride=2)  # for an MNIST batch [N, 1, 28, 28] this gives [N, 4, 196]
x = torch.stack([f_i(x[..., i]) for i in range(x.shape[-1])], dim=1)  # stack the (N,) outputs into [N, 196]; this loop is very slow!

Here, whether I process 1 image or many more in a batch makes little difference. Every function f_i processes the whole batch, but only a part of each image. You could, I guess, think of the f_i's as local receptive fields, and we can parallelize the computation of their activations. Afterwards, x passes on to other modules in the model.
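
For concreteness, here is a minimal, self-contained sketch of the kind of model I mean (the class name, the per-patch nn.Linear, and the sizes are just illustrative stand-ins for my actual f_i's):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchwiseModel(nn.Module):  # illustrative stand-in, not my real model
    def __init__(self, n_patches=196, patch_dim=4):
        super().__init__()
        # one small learnable f_i per 2x2 patch location
        self.f = nn.ModuleList(nn.Linear(patch_dim, 1) for _ in range(n_patches))

    def forward(self, x):                          # x: [N, 1, 28, 28] MNIST batch
        x = F.unfold(x, kernel_size=2, stride=2)   # [N, 4, 196]
        x = x.transpose(1, 2)                      # [N, 196, 4], one row per patch location
        # the slow part: a Python loop over the 196 patch locations
        outs = [self.f[i](x[:, i, :]).squeeze(-1) for i in range(x.size(1))]  # each of shape [N]
        return torch.stack(outs, dim=1)            # [N, 196], which then goes to other modules

model = PatchwiseModel()
out = model(torch.randn(64, 1, 28, 28))            # out: [64, 196]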
The closest I have seen to my question is this and this. However, the former seems to parallelize over the data, while the latter, which is closer to what I intend, has no solution.

A search indicated that this error is thrown when the pool object itself ends up being pickled (say, because the function you are trying to parallelize captures the pool, so pickling the function pickles the pool too).
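
Just to spell out what I understand by that (a toy check, not related to my actual training code): a Pool refuses to be pickled at all, so anything that drags the pool into what gets pickled fails the same way.

import pickle
import torch.multiprocessing as mp

if __name__ == "__main__":
    pool = mp.Pool(2)
    try:
        pickle.dumps(pool)   # same thing happens if a closure or bound method captures the pool
    except NotImplementedError as e:
        print(e)             # "pool objects cannot be passed between processes or pickled"
    pool.close()
    pool.join()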

I have tried this, and it does seem to avoid pickling the pool object, but the program just stops responding at some point, so I guess this is not possible.
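
Roughly the shape of what I tried (heavily simplified; apply_fi, get_pool, and the module-level pool are placeholders, not my exact code). The pool lives at module scope, so the function handed to pool.map never captures it:

import torch
import torch.multiprocessing as mp

_pool = None

def get_pool(n_workers=4):
    global _pool
    if _pool is None:
        _pool = mp.Pool(n_workers)
    return _pool

def apply_fi(args):
    f_i, col = args          # f_i: one small module, col: [N, 4] patch column
    with torch.no_grad():    # the autograd graph would not cross process boundaries anyway
        return f_i(col)

# inside forward(), after unfolding/transposing x to [N, 196, 4]:
#     jobs = [(self.f[i], x[:, i, :]) for i in range(x.size(1))]
#     outs = get_pool().map(apply_fi, jobs)   # roughly where the program stops responding
#     return torch.stack([o.squeeze(-1) for o in outs], dim=1)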

Sorry for the long answer. Any comment would be very much appreciated!