Unexpected behavior when using two multiprocessing dataloaders

Hi,
I am training a model on two different tasks (which require different data) in parallel. I'm using two data loaders, one for each dataset, where one dataset is iterable-style (an IterableDataset) and the other is map-style. My training loop looks something like this:

dataloader2_iterator = iter(dataloader2)
for batch1 in dataloader1:
    # pull the matching batch for the second task
    batch2 = next(dataloader2_iterator)
    train_model(batch1, batch2)
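For reference, here is a self-contained sketch of the same two-loader pattern. StreamDataset, IndexedDataset, the tensor shapes and the sizes are hypothetical stand-ins for the real datasets, train_model is a stub, and the StopIteration handling is an addition that the loop above does not have (without it, next() raises as soon as dataloader2 runs out before dataloader1 does).

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info


class StreamDataset(IterableDataset):
    """Hypothetical stand-in for the iterable-style dataset (task 1)."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        # Shard the stream across workers so each sample is produced once.
        info = get_worker_info()
        start, step = (0, 1) if info is None else (info.id, info.num_workers)
        for i in range(start, self.n, step):
            yield torch.tensor([float(i)]), torch.tensor(i % 2)


class IndexedDataset(Dataset):
    """Hypothetical stand-in for the map-style dataset (task 2)."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.full((3,), float(idx)), torch.tensor(idx % 5)


def train_model(batch1, batch2):
    """Placeholder for the real training step."""


def run(num_workers: int = 2, n1: int = 64, n2: int = 32, batch_size: int = 8) -> int:
    """Run the two-loader loop once over dataloader1 and return the step count."""
    dataloader1 = DataLoader(StreamDataset(n1), batch_size=batch_size, num_workers=num_workers)
    dataloader2 = DataLoader(IndexedDataset(n2), batch_size=batch_size, shuffle=True,
                             num_workers=num_workers)
    dataloader2_iterator = iter(dataloader2)
    steps = 0
    for batch1 in dataloader1:
        try:
            batch2 = next(dataloader2_iterator)
        except StopIteration:
            # dataloader2 is exhausted: start a fresh pass over it.
            dataloader2_iterator = iter(dataloader2)
            batch2 = next(dataloader2_iterator)
        train_model(batch1, batch2)
        steps += 1
    return steps
```

Calling run(num_workers=2) from under an `if __name__ == "__main__":` guard gives the both-loaders-multiprocess setup; run(num_workers=0) runs everything in the main process.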

When both dataloaders are multiprocess (i.e. num_workers > 0), training works fine for a while and then crashes with the following error:
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1347, in _next_data
    return self._process_data(data)
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1373, in _process_data
    data.reraise()
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 180, in default_collate
    return [default_collate(samples) for samples in transposed] # Backwards compatibility.
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 180, in <listcomp>
    return [default_collate(samples) for samples in transposed] # Backwards compatibility.
  File "/home/gamir/DER-Roei/alon/anaconda3/envs/open_clip/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 146, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: torch.cat(): input types can't be cast to the desired output type Long
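The failing line is the torch.stack(batch, 0, out=out) call in default_collate. Going by the torch/utils/data/_utils/collate.py source of recent releases, default_collate only passes a preallocated out tensor (in shared memory, with the dtype of the batch's first sample) when it runs inside a worker process; in the main process out is None, and torch.stack promotes mixed dtypes instead of failing. The sketch below mimics that difference with plain tensors. collate_like_worker is a hypothetical, simplified stand-in (it uses new_empty rather than shared memory), and a Long/float mix within one batch is an assumed cause, not something confirmed for this dataset. If that is what's happening, it would also fit the crash disappearing when the loader producing the mixed dtypes runs with num_workers=0.

```python
import torch


def collate_like_worker(batch):
    """Roughly mimic default_collate inside a worker: preallocate the output
    with the FIRST sample's dtype and stack into it."""
    elem = batch[0]
    out = elem.new_empty((len(batch), *elem.shape))
    return torch.stack(batch, 0, out=out)


long_sample = torch.tensor([1, 2])       # int64 ("Long")
float_sample = torch.tensor([1.5, 2.5])  # float32

# Main-process path (out=None): dtypes are promoted, no error.
promoted = torch.stack([long_sample, float_sample], 0)
print(promoted.dtype)  # torch.float32

# Worker-like path: out is Long, and float32 values can't be cast into it.
try:
    collate_like_worker([long_sample, float_sample])
except RuntimeError as err:
    print("RuntimeError:", err)
```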

When only one of the dataloaders is multiprocess and the other is single-process, this doesn't happen and training runs fine all the way through. I've debugged everything I can, and I can't find any cause for this error other than a possible collision between the two dataloaders' workers when both use multiprocessing.
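Since the error message is about input types that can't be cast to Long during stacking, one way to test whether the data itself is involved (rather than the workers) is a debugging collate_fn that checks each field of a batch has a single dtype before handing it to default_collate. checked_collate is a hypothetical helper, and it assumes each sample is a tuple or list of tensors, e.g. (image, label).

```python
import torch
from torch.utils.data.dataloader import default_collate


def checked_collate(batch):
    """Hypothetical debugging collate_fn: fail loudly on mixed dtypes
    within one batch, then defer to default_collate."""
    first = batch[0]
    # Assumption: samples are tuples/lists of tensors, e.g. (image, label).
    if isinstance(first, (tuple, list)):
        for field_idx in range(len(first)):
            dtypes = {
                sample[field_idx].dtype
                for sample in batch
                if isinstance(sample[field_idx], torch.Tensor)
            }
            if len(dtypes) > 1:
                raise TypeError(
                    f"field {field_idx} has mixed dtypes in one batch: "
                    f"{sorted(map(str, dtypes))}"
                )
    return default_collate(batch)
```

It would be passed as DataLoader(dataset, batch_size=..., num_workers=..., collate_fn=checked_collate). An exception raised in a worker is re-raised in the main process, as the "Caught RuntimeError in DataLoader worker process 0" line shows, so the field index should surface in the training loop.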

Does anyone know what might be going wrong here and how I can fix it?
Thanks

What PyTorch version are you using?