Iterable dataloader threading error


import math

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            print("worker_info is none")  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            print("worker_info is something:", worker_info) 
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds,)))
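
For the multi-worker case, the corresponding call from the docs example would be something like this (with num_workers=2, the split logic above gives worker 0 the values [3, 4] and worker 1 the values [5, 6]):

# Multi-process loading with two worker processes.
# per_worker = ceil((7 - 3) / 2) = 2, so worker 0 iterates range(3, 5)
# and worker 1 iterates range(5, 7); samples are interleaved across workers.
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))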

I get this warning:

/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning:

os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.

This is something I'm experiencing with the iterable dataset when num_workers > 0. How do I fix this? And why is it slower with more workers?

You might want to remove the JAX code from your PyTorch script, assuming you are not using it.

There is no JAX code. I borrowed this code from the PyTorch IterableDataset docs, and I don't know what causes this.

The warning seems to be raised from jax/_src/xla_bridge.py, so you might want to check which line of code raises it or where the import might be.
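
A minimal sketch for tracking it down, assuming the message really is a Python RuntimeWarning as shown above: promote it to an exception so you get a full traceback at the point it is raised, and check whether some other package has already imported jax:

import sys
import warnings

# Turn the fork-related RuntimeWarning into an error so the traceback
# shows which call path triggers it (the message regex is an assumption).
warnings.filterwarnings("error", message=".*os.fork.*", category=RuntimeWarning)

# If another package imported jax during startup, it will already be in sys.modules.
print("jax already imported:", "jax" in sys.modules)

Running your script with python -X importtime 2>&1 | grep jax is another way to see which import pulls jax in.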

I've tried searching to no avail. Would you reckon it's Python and I should try updating Python? Even in the code, as you can see, there is no JAX; I think this is an issue with the iterable dataset.

No, I doubt the error is raised by Python itself, as it would be unaware of JAX.
Your code is missing the import torch statement, and after adding it I do not see any warning on my system but only this output:

worker_info is none
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

After installing JAX and adding:

import jax
jax._src.xla_bridge.backends()

I see the same warning.
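
As a side note, if you actually need JAX in the same environment, a possible workaround (just a sketch, not something I have verified here) is to avoid os.fork() by making the DataLoader start its workers with the spawn method:

# "spawn" starts fresh interpreter processes instead of forking,
# so the multithreaded-fork warning should not apply.
loader = torch.utils.data.DataLoader(ds, num_workers=2, multiprocessing_context="spawn")
print(list(loader))

Note that spawn requires the dataset class to be importable and the entry point to be guarded by if __name__ == "__main__":.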


You, sir, are a genius!! I uninstalled JAX and it worked. I had JAX pre-installed in my machine learning VM.
