Hi, I'm currently using `torch.multiprocessing` to send the outputs of a neural network to another process. To keep the discussion simple, I have two processes. The first one loads the training data, runs the forward pass of the network, and sends the results to the second one. The second one receives the results from the first process and handles them.
My problem is that using multiple processes slows down the data loader. At the beginning of each epoch, the forwarding process unexpectedly stalls for a while and then continues forwarding. When I remove the multiprocessing (i.e. just forward the network without sending data across processes), everything runs smoothly and the stall at the beginning of each epoch disappears.
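To put a number on the stall, a small helper like the one below times how long each epoch waits for its first batch. It is a sketch: `first_batch_latencies` is a hypothetical name, not a PyTorch function, and it accepts any re-iterable object, including a `torch.utils.data.DataLoader`.

```python
import time


def first_batch_latencies(loader, num_epochs):
    """For each epoch, measure the seconds from starting iteration to the first batch."""
    latencies = []
    for _ in range(num_epochs):
        start = time.perf_counter()
        # For a DataLoader with num_workers > 0 (and no persistent workers),
        # creating the iterator launches a fresh set of worker processes.
        batches = iter(loader)
        next(batches, None)  # block until the first batch is ready (None if empty)
        latencies.append(time.perf_counter() - start)
        for _ in batches:  # consume the rest of the epoch
            pass
    return latencies


if __name__ == '__main__':
    # Toy iterable standing in for the real trainloader.
    print(first_batch_latencies([[0.0] * 4] * 3, num_epochs=2))
```

Calling it on `trainloader` once with and once without the `mp.set_start_method('spawn')` and queue setup would show whether the per-epoch delay follows the multiprocessing setup rather than the data itself.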
Below is a simple Python 3 script (run on a single GTX 1080) that loads a pretrained network and uses the CIFAR-10 dataset with `num_workers=2` and `pin_memory=True`. You will observe that the program stalls between `for epoch in range(total_epoch)` and `for step, datas in enumerate(trainloader)`.
Does anyone have an idea why this happens and how to solve it? (I suspect it is caused by shared-memory limits, but I still don't know how to avoid them.)
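One quick way to probe the shared-memory suspicion, assuming Linux, where POSIX shared memory is backed by the tmpfs mounted at `/dev/shm`, is to check how large that mount is and how full it gets while the script runs. Batches returned by DataLoader workers are CPU tensors passed between processes through shared memory, and `torch.multiprocessing.get_sharing_strategy()` reports which sharing strategy PyTorch is using. The helper name `shm_usage` below is hypothetical; it is a sketch built only on the standard library.

```python
import os
import shutil


def shm_usage(path='/dev/shm'):
    """Return (total, used, free) bytes of the shared-memory mount, or None if it is missing."""
    if not os.path.isdir(path):
        return None
    usage = shutil.disk_usage(path)
    return usage.total, usage.used, usage.free


if __name__ == '__main__':
    info = shm_usage()
    if info is None:
        print('no /dev/shm on this system')
    else:
        total, used, free = info
        print('shm total {:.1f} MiB, used {:.1f} MiB, free {:.1f} MiB'.format(
            total / 2**20, used / 2**20, free / 2**20))
```

Running it (or simply `df -h /dev/shm`) in a second terminal while the script is at the start of an epoch would show whether the mount is close to full at the moment of the stall.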
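```python
import os

import torch
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms


def get_a_pretrained_model(model_name):
    # Hypothetical stand-in for the helper not shown in the post: builds the
    # torchvision architecture; the real helper also loads pretrained weights.
    return getattr(torchvision.models, model_name)()


def get_cifar10_trainloader(batch_size):
    # Hypothetical stand-in: CIFAR-10 train split, num_workers=2, pin_memory=True
    # (shuffle and transform are assumptions).
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=transforms.ToTensor())
    return torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True,
        num_workers=2, pin_memory=True)


def recv(buffer):
    # Consumer process: handle results until the None sentinel arrives.
    while True:
        recved = buffer.get()
        if recved is None:
            break
        # do something with recved
        del recved


def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda:0')
    model_name = 'vgg16'
    total_epoch = 20
    batch_size = 128

    # get model and dataset
    model = get_a_pretrained_model(model_name)
    trainloader = get_cifar10_trainloader(batch_size)
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    model.eval()

    # CUDA in child processes requires the 'spawn' (or 'forkserver') start method
    mp.set_start_method('spawn')
    buffer = mp.Queue()
    recv_proc = mp.Process(target=recv, args=(buffer,))
    recv_proc.start()

    # forward the model and send the results
    for epoch in range(total_epoch):
        # the program gets stuck here at the start of every epoch
        for step, datas in enumerate(trainloader):
            inputs, labels = datas
            inputs, labels = inputs.to(device), labels.to(device)
            outs = model(inputs)
            buffer.put(outs)
            if step % 20 == 0:
                print('Epoch [{}]\t Step [{}]'.format(epoch, step))

    # tell the consumer to stop, then wait for it to exit
    buffer.put(None)
    recv_proc.join()


if __name__ == '__main__':
    main()
```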