Multiprocessing is much slower than a for loop in forward()

For reasons specific to my task, I need to use a for loop to process each batch differently during training. To make it faster, I want to use multiprocessing to dispatch different batches to different processes. However, I found that using multiprocessing is even slower than a simple for loop. What’s wrong with my code?

Here’s my example code:

import torch
from torch import nn
from torch.multiprocessing import Pool, set_start_method
from tqdm import tqdm

try:
    set_start_method('spawn', force=True)
except RuntimeError:
    pass

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = nn.Linear(128, 128)

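    def forward(self, x, pool):
        output = self.model(x)
        # Select either the multiprocessing path or the plain for loop here.
        # Neither affects `output`; they only add the extra per-row work.
        self.multiprocess(x, pool)
        # self.singleprocess(x, pool)

        return output

    def f(self, x):
        return x ** 2

    def multiprocess(self, x, pool):
        # imap submits one task per row of x (chunksize defaults to 1).
        res = list(pool.imap(self.f, x))

    def singleprocess(self, x, pool):
        # Squares each row of x in place, in the current process.
        for idx in range(len(x)):
            x[idx] = x[idx] ** 2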

if __name__ == '__main__':
    model = Model()
    with Pool(processes=8) as pool:
        for _ in tqdm(range(1000)):
            # Fake data: a batch of 100 rows with 128 features each.
            x = torch.ones(100, 128)
            output = model(x, pool)

The results show that multiprocessing is at least 36 times slower than the for loop.
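One plausible contributor, offered as a hypothesis rather than a confirmed diagnosis: `pool.imap(self.f, x)` uses the default `chunksize=1`, so each of the 100 rows becomes its own task (100,000 tasks over 1000 iterations), and every task carries the function being mapped. Because that function is the bound method `self.f`, pickling it also pickles the whole model. The sketch below measures this with the standard `pickle` module only, using a hypothetical plain-Python `Model` whose 128 x 128 list of floats stands in for `nn.Linear(128, 128)`. `torch.multiprocessing` registers reducers that move tensor storage into shared memory instead of copying the bytes, so real numbers will differ, but the model is still serialized and rebuilt in the worker for every task.

```python
import pickle
import time


class Model:
    """Hypothetical stand-in for the post's nn.Module: it owns a 128 x 128
    weight matrix, stored as plain Python floats instead of a tensor."""

    def __init__(self, size: int = 128) -> None:
        self.weight = [[float(i * size + j) for j in range(size)] for i in range(size)]

    def f(self, row: list) -> list:
        # Same per-row work as in the post: square every element.
        return [v ** 2 for v in row]


def payload_bytes(obj) -> int:
    """Number of bytes the standard pickler produces for `obj`."""
    return len(pickle.dumps(obj))


model = Model()
row = [1.0] * 128  # one row of the fake 100 x 128 batch

method_bytes = payload_bytes(model.f)  # a bound method pickles `model` with it
row_bytes = payload_bytes(row)

# Round-trip the bound method once per row, as imap(self.f, x) with the
# default chunksize=1 would for a 100-row batch (100 tasks).
start = time.perf_counter()
for _ in range(100):
    pickle.loads(pickle.dumps(model.f))
serialize_seconds = time.perf_counter() - start

# The actual work for the same 100 rows, done in-process.
start = time.perf_counter()
for _ in range(100):
    model.f(row)
compute_seconds = time.perf_counter() - start

print(f"bound method: {method_bytes} bytes, one row: {row_bytes} bytes")
print(f"serialize x100: {serialize_seconds:.4f} s, compute x100: {compute_seconds:.4f} s")
```

Mapping a module-level function instead of a bound method, and passing a larger `chunksize` to `imap`, should reduce how often this cost is paid. For per-row work as cheap as `x ** 2`, inter-process overhead may still outweigh any parallel speedup.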