I’ve been trying to set up parallelisation for an object detection model I’ve trained, in order to improve its throughput when running on CPU. To do this, I’m roughly following this blog post on implementing Hogwild in PyTorch.
Unfortunately, when running my script, the processes appear to hang while trying to iterate through the DataLoader. Iterating through the DataLoader before calling mp.Process works as expected, but iterating within the process causes the program to freeze.
I’ve provided a minimal example below:
import torch
import torch.multiprocessing as mp
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler


def loop_thru(dataloader):
    for i, _ in enumerate(dataloader):
        print(i)


if __name__ == "__main__":
    # Settings
    num_processes = 1

    # Create data
    data = torch.ones((1024, 3, 352, 352)).float() / 255
    dataset = TensorDataset(data)

    # Start processes to loop through data
    processes = []
    for rank in range(num_processes):
        sampler = DistributedSampler(
            dataset=dataset, num_replicas=num_processes, rank=rank, shuffle=False
        )
        dataloader = DataLoader(
            dataset=dataset,
            sampler=sampler,
            batch_size=16,
        )
        p = mp.Process(target=loop_thru, args=(dataloader,))
        p.start()
        processes.append(p)

    for p in processes:
        print(f"Joining: {p}")
        p.join()
When I run this script, the outputs are as follows before the program hangs:
Any idea what I’m doing wrong here?
Edit: Running pytorch=1.10.1
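For comparison, here is a rough sketch of a variant of the script above that might be worth testing. It rests on two assumptions that have not been verified against this setup: that the hang is related to forking the parent after it has already run tensor operations, and that building the sampler and DataLoader inside each child under the "spawn" start method sidesteps it. The helper make_loader, the smaller tensor shape, and the torch.set_num_threads(1) call are illustrative choices, not part of the original script.

```python
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_loader(dataset, rank, num_replicas, batch_size=16):
    # Hypothetical helper: build the sampler and DataLoader for one shard.
    sampler = DistributedSampler(
        dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=False
    )
    return DataLoader(dataset=dataset, sampler=sampler, batch_size=batch_size)


def loop_thru(dataset, rank, num_replicas):
    # Assumption: one intra-op thread per process avoids CPU oversubscription.
    torch.set_num_threads(1)
    # The DataLoader is created in the child instead of being passed in.
    dataloader = make_loader(dataset, rank, num_replicas)
    for i, _ in enumerate(dataloader):
        print(f"rank {rank}: batch {i}")


if __name__ == "__main__":
    num_processes = 2

    # Smaller than the original (1024, 3, 352, 352) tensor to keep the run short.
    data = torch.ones((64, 3, 32, 32)).float() / 255
    dataset = TensorDataset(data)

    # Assumption: "spawn" starts each child with a fresh interpreter, so no
    # thread-pool state is inherited from the parent via fork.
    ctx = mp.get_context("spawn")
    processes = []
    for rank in range(num_processes):
        p = ctx.Process(target=loop_thru, args=(dataset, rank, num_processes))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
```

If this variant runs to completion while the original still hangs, that would point at the start method or at where the DataLoader is created, rather than at DistributedSampler itself.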