For some reason, I need to use a for loop to process each batch differently during training. To speed this up, I want to use multiprocessing to dispatch each batch to a different process. However, I found that using multiprocessing is even slower than a simple for loop. What's wrong with my code?
Here's my example code:
```python
from torch.multiprocessing import Pool, Process, set_start_method
try:
    set_start_method('spawn', force=True)
except RuntimeError:
    pass
from torch import nn
import torch
from tqdm import tqdm

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = nn.Linear(128, 128)

    def forward(self, x, pool):
        output = self.model(x)
        ### select multiprocess or for loop here ###
        self.multiprocess(x, pool)
        # self.singleprocess(x, pool)
        return output

    def f(self, x):
        return x ** 2

    def multiprocess(self, x, pool):
        res = list(pool.imap(self.f, x))

    def singleprocess(self, x, pool):
        for idx in range(len(x)):
            x[idx] = x[idx] ** 2

if __name__ == '__main__':
    model = Model()
    pool = Pool(processes=8)
    for iter in tqdm(range(1000)):
        ### fake data ###
        x = torch.ones(100, 128)
        output = model(x, pool)
```
The results show that multiprocessing is at least 36 times slower than the simple for loop.