Multiprocessing slows the dataloader at the beginning of each epoch

Hi, I’m currently using torch.multiprocessing to send the outputs of a neural network to another process. For simplicity, assume two processes: the first loads the training data, runs the forward pass, and sends the results to the second; the second receives the results and handles them.

The problem I’m running into is that multiprocessing slows down the data loader. At the beginning of every epoch, the forwarding process unexpectedly gets stuck for a while before it continues. When I remove the multiprocessing (i.e. just run the forward pass without sending data across processes), everything runs smoothly and the stall at the start of each epoch disappears.

Below is a simple Python 3 script (run on a single GTX 1080) that loads a pretrained network and uses the CIFAR-10 dataset with num_workers=2 and pin_memory=True. You will observe that the program gets stuck between the for epoch in range(total_epoch) line and the for step, datas in enumerate(trainloader) line.

Does anyone have an idea why this happens and how to solve it? (I suspect it is caused by shared memory limits, but I still don’t know how to avoid it.)

import os

import torch
import torch.multiprocessing as mp

def main():

  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  device = torch.device('cuda:0')

  model_name = 'vgg16'
  dataset = 'cifar10'

  total_epoch = 20
  batch_size = 128

  # get model and datasets (placeholders for the actual loading code)
  model = get_a_pretrained_model()
  trainloader = get_cifar10_trainloader()  # num_workers=2, pin_memory=True
  for param in model.parameters():
    param.requires_grad = False
  model = model.to(device)
  model.eval()

  mp.set_start_method('spawn')
  buffer = mp.Queue()
  recv_proc = mp.Process(target=recv, args=(buffer,))
  recv_proc.start()

  # forwarding model and send results
  for epoch in range(total_epoch):
    # the program gets stuck here
    for step, datas in enumerate(trainloader):
      inputs, labels = datas
      inputs, labels = inputs.to(device), labels.to(device)
      outs = model(inputs)
      buffer.put(outs)
      if step % 20 == 0:
        print('Epoch [{}]\t Step [{}]'.format(epoch, step))
  buffer.put(None)

  recv_proc.join()
  

def recv(buffer):
  while True:
    recved = buffer.get()
    if recved is None:
      break
    # do something with recved
    del recved


if __name__ == '__main__':
  main()

I encountered the same problem. Have you solved it?


I’ve run into the same problem, can anybody answer it?

Update PyTorch to 1.7; the DataLoader argument persistent_workers will help. See this link: What are the (dis)advantages of persistent_workers
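
For reference, a minimal sketch of enabling persistent_workers (assuming PyTorch >= 1.7 and the standard torchvision CIFAR-10 dataset; the root path and transform are placeholders):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# persistent_workers=True keeps the dataloader worker processes alive
# between epochs instead of re-spawning them, which removes the stall
# at the start of every epoch.
trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                            transform=transforms.ToTensor())
trainloader = DataLoader(trainset,
                         batch_size=128,
                         shuffle=True,
                         num_workers=2,
                         pin_memory=True,
                         persistent_workers=True)  # requires PyTorch >= 1.7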


Thank you so much, I’ve been facing this problem for a month, so glad to find your answer!

I changed the start method to fork with mp.set_start_method('fork'). It solved the problem.
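
For completeness, a minimal sketch of that change against the script above. Note that 'fork' is only available on Unix-like systems, and the PyTorch docs recommend spawn or forkserver when subprocesses use CUDA, so whether this is safe depends on what the receiving process does with the tensors:

import torch.multiprocessing as mp

# With 'fork', child processes (the receiver and the dataloader workers)
# are created by copying the parent instead of re-importing the script,
# so they start much faster than with 'spawn'.
# Caveat: the CUDA runtime does not support fork, so the forked receiver
# must not initialize CUDA itself.
mp.set_start_method('fork')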