For some reason, I need to use a for loop to process each batch differently in training. To make it faster, I want to use multiprocessing to deploy different batch on different process. However, I found that using multiprocessing is even slower than a simple for loop. What’s wrong with my code?
Here’s my example code:
from torch.multiprocessing import Pool, Process, set_start_method
try:
set_start_method('spawn', force=True)
except RuntimeError:
pass
from torch import nn
import torch
from tqdm import tqdm
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.model = nn.Linear(128, 128)
def forward(self, x, pool):
output = self.model(x)
### select multiprocess or for loop here ###
self.multiprocess(x, pool)
# self.singleprocess(x, pool)
return output
def f(self, x):
return x ** 2
def multiprocess(self, x, pool):
res = list(pool.imap(self.f, x))
def singleprocess(self, x, pool):
for idx in range(len(x)):
x[idx] = x[idx] ** 2
if __name__ == '__main__':
model = Model()
pool = Pool(processes=8)
for iter in tqdm(range(1000)):
### fake data ###
x = torch.ones(100, 128)
output = model(x, pool)
The results shows that multiprocessing is at least 36 times slower than for loop.