DataLoader: bug when num_workers > 0

Hi!

@ptrblck any update on this? Any words of wisdom how to work with HDF5 files efficiently using multiple processes?

Thanks,
Piotr

Hoping to see a better solution for working with HDF5 files.

Unfortunately I haven’t looked into this issue further, as the suggested solution from @sytrus-pytorch seems to run correctly. I’m not sure how large the overhead of reopening the HDF5 file continuously would be.
Did you try this approach?

Yeah, I’ve explored the topic a bit, and here is what I found:

  • With version 1.8 of the HDF5 library, working with HDF5 files and multiprocessing is a lot messier (not h5py! I mean the HDF5 library installed on your system: https://unix.stackexchange.com/questions/287974/how-to-check-if-hdf5-is-installed). I highly recommend updating the library to version 1.10, where multiprocessing works better. With 1.8 I was only able to get h5py to work using a "with" statement, and this seems to introduce a huge overhead, but I didn't have time to investigate it properly:
import h5py
from torch.utils.data import Dataset


class H5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_path = h5_path

    def __getitem__(self, index):
        # Re-open the file on every access so no open handle is shared across workers
        with h5py.File(self.h5_path, 'r') as file:
            # Do something with the file and return data
            return file["dataset"][index]

    def __len__(self):
        with h5py.File(self.h5_path, 'r') as file:
            return len(file["dataset"])
  • With version 1.10 of the HDF5 library, I was able to open the h5py.File once in __getitem__ and reuse it without errors:
class H5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_path = h5_path
        self.file = None

    def __getitem__(self, index):
        # Open the file lazily, once per worker process, and keep it open
        if self.file is None:
            self.file = h5py.File(self.h5_path, 'r')
        # Do something with the file and return data
        return self.file["dataset"][index]

    def __len__(self):
        with h5py.File(self.h5_path, 'r') as file:
            return len(file["dataset"])

So then it seems to be a problem related to HDF5 in general in combination with multiprocessing, rather than with PyTorch itself? If so, it would be interesting to see how this issue evolves when using the C++ frontend.

So I investigated it further, and indeed opening the HDF5 file introduces huge overhead. I've tested it on this code: https://github.com/piojanu/World-Models (my implementation of the World Models paper, hereafter WM; the memory training is written in PyTorch). Note: the code I link here doesn't have multiprocessing data preloading capabilities; I tested that in a private repo.
I used the Pyflame profiler to profile the WM's memory module training for 30 s, sampling every 1 ms, on this hardware: Intel® Core™ i7-7700 CPU @ 3.60GHz with a GeForce GTX 1060 6GB. (A rough way to measure batches per second yourself is sketched after the experiment list below.)

Experiments:

  1. With data loading in the main process (DataLoader's num_workers = 0) and opening the HDF5 file each time in __getitem__:
    • Batches per second: ~0.18
    • Most of the time is spent loading data: above 70% of the profiling time.
    • Opening the HDF5 file takes 20% of the profiling time!
    • Data preprocessing and memory copies take the last 10% of the profiling time.
    • Training the one-layer LSTM on the GPU is so fast that the profiler didn't catch it.
  2. With data loading in the main process (DataLoader's num_workers = 0) and opening the HDF5 file once in __getitem__:
    • Batches per second: ~2
    • Most of the time is still spent loading data: ~90% of the profiling time.
    • There is no overhead from opening the HDF5 file, of course, which is why a larger proportion of the time went to loading the data.
    • The profiler caught a couple of samples of LSTM training, still below 1% of the profiling time.
  3. With data loading in worker processes (DataLoader's num_workers = 4) and opening the HDF5 file once in __getitem__:
    • Batches per second: ~5.1
    • There is no overhead from opening the HDF5 file, and data loading is fully overlapped with GPU execution. The DataLoader's __next__ operation (getting the next batch) in the main process takes below 1% of the profiling time and we get full utilisation of the GTX 1060! Win :wink:
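As a minimal sketch (independent of Pyflame), here is one way to measure batches per second for a DataLoader; H5Dataset and the file path are placeholders based on the snippets in this thread, and the batch size and worker count are illustrative:

import time

from torch.utils.data import DataLoader


def measure_batches_per_second(loader, max_batches=100):
    """Iterate over the loader and report rough throughput."""
    start = time.time()
    n = 0
    for n, _batch in enumerate(loader, start=1):
        if n >= max_batches:
            break
    return n / (time.time() - start)


# Hypothetical usage:
# loader = DataLoader(H5Dataset("data.h5"), batch_size=64, num_workers=4)
# print("%.2f batches/s" % measure_batches_per_second(loader))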

My recommendations:

  • Use HDF5 version 1.10 (better multiprocessing handling).
  • Because an open HDF5 file isn't picklable, and the Dataset needs to be serialised with pickle to be sent to the worker processes, you can't open the HDF5 file in __init__. Open it in __getitem__ and cache it as a singleton! Do not open it on each call, as that introduces huge overhead.
  • Use DataLoader with num_workers > 0 (reading from HDF5, i.e. from the hard drive, is slow) and batch_sampler (random access to HDF5, i.e. to the hard drive, is slow).

Sample code:

import h5py
import torch


class H5Dataset(torch.utils.data.Dataset):
    def __init__(self, path):
        self.file_path = path
        self.dataset = None
        # Open the file briefly just to read the dataset length
        with h5py.File(self.file_path, 'r') as file:
            self.dataset_len = len(file["dataset"])

    def __getitem__(self, index):
        # Lazily open the file once per worker and cache the dataset handle
        if self.dataset is None:
            self.dataset = h5py.File(self.file_path, 'r')["dataset"]
        return self.dataset[index]

    def __len__(self):
        return self.dataset_len
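For completeness, here is one way the dataset above could be wired into a DataLoader following these recommendations; the file path, batch size and worker count are illustrative, not values from my experiments:

from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

dataset = H5Dataset("data.h5")  # hypothetical path

# SequentialSampler keeps HDF5 reads mostly contiguous (random access to the
# file, i.e. to the hard drive, is slow); BatchSampler groups indices into
# batches for the worker processes.
sampler = BatchSampler(SequentialSampler(dataset), batch_size=64, drop_last=False)

loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)

for batch in loader:
    pass  # training step goes here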

thank you mate! A working but cumbersome workaround


Thanks for the detailed profiling and analysis! The code looks great and I think it’ll be useful for a lot of users here in the board. :slight_smile:


A major problem, however, might be that the entire data set is loaded into RAM, thus preventing slicing and effective on-demand loading of data.

Also, is there a reason why my GPU workload (1080 Ti) regularly drops to 0% for a short period of time before going back up to 100% when using 4 workers, compared to 0 workers where I did not observe this problem?

I don't think that the entire data set is loaded into RAM. An HDF5 file can hold TBs of data (certainly more than you have RAM in your workstation) and this code should still work fine. HDF5 (through the h5py interface in Python) makes sure to keep only the needed data in RAM (the currently accessed slice).
Of course, if your batch size and the DataLoader's prefetch queue size are too big (or some copy isn't freed/garbage collected), it might not fit into RAM. You can control the batch size with the DataLoader's batch_size parameter.

Is the dropping happening before each new epoch? If yes, then I see it too. This is caused by DataLoader initialisation (forking processes/spawning workers, initialising queues etc.). This overhead isn't observed when num_workers = 0, of course, because then there are no workers to spawn.
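To illustrate the point about memory, here is a tiny sketch (with a made-up file name and size) showing that indexing an h5py dataset reads only the requested slice from disk, so the file can be much larger than RAM:

import h5py
import numpy as np

# Create a toy file; in practice the file already exists and may hold TBs.
with h5py.File("toy.h5", "w") as f:
    f.create_dataset("dataset", data=np.arange(1_000_000, dtype=np.float32))

with h5py.File("toy.h5", "r") as f:
    dset = f["dataset"]    # just a handle, no data read yet
    chunk = dset[100:200]  # only these 100 values are read into RAM
    print(chunk.shape)     # (100,)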


thank you, your answers are very detailed and your contributions to this community are invaluable =)

I wonder how this generalizes to the more common case in which there are multiple datasets within the HDF5 file that need to be accessed by index. Is code duplication and loading the respective sets into torch.Tensors the way to go here?

Thank you for kind words :smile:

Here is an example of non-trivial Dataset that I use for preprocessing data for World Models’ memory module training:

import h5py
import numpy as np
import torch
from torch.distributions import Normal
from torch.utils.data import Dataset


class MemoryDataset(Dataset):
    """Dataset of sequential data to train memory.
    Args:
        dataset_path (string): Path to HDF5 dataset file.
        sequence_len (int): Desired output sequence len.
        terminal_prob (float): Probability of sampling sequence that finishes with
            terminal state.
        dataset_fraction (float): Fraction of dataset to use during training, value range: (0, 1]
            (dataset forepart is taken).
        is_deterministic (bool): If return sampled latent states or mean latent states.
    Note:
        Arrays should have the same size of the first dimension and their type should be the
        same as desired Tensor type.
    """

    def __init__(self, dataset_path, sequence_len, terminal_prob, dataset_fraction, is_deterministic):
        assert 0 < terminal_prob and terminal_prob <= 1.0, "0 < terminal_prob <= 1.0"
        assert 0 < dataset_fraction and dataset_fraction <= 1.0, "0 < dataset_fraction <= 1.0"

        self.dataset = None
        self.dataset_path = dataset_path
        self.sequence_len = sequence_len
        self.terminal_prob = terminal_prob
        self.dataset_fraction = dataset_fraction
        self.is_deterministic = is_deterministic

        # https://stackoverflow.com/questions/46045512/h5py-hdf5-database-randomly-returning-nans-and-near-very-small-data-with-multi
        with h5py.File(self.dataset_path, "r") as dataset:
            self.latent_dim = dataset.attrs["LATENT_DIM"]
            self.action_dim = dataset.attrs["ACTION_DIM"]
            self.n_games = dataset.attrs["N_GAMES"]

    def __getitem__(self, idx):
        """Get sequence at random starting position of given sequence length from episode `idx`."""

        offset = 1

        if self.dataset is None:
            self.dataset = h5py.File(self.dataset_path, "r")

        t_start, t_end = self.dataset['episodes'][idx:idx + 2]
        episode_length = t_end - t_start
        if self.sequence_len <= episode_length - offset:
            sequence_len = self.sequence_len
        else:
            sequence_len = episode_length - offset
            # log.info(
            #     "Episode %d is too short to form full sequence, data will be zero-padded.", idx)

        # Sample where to start sequence of length `self.sequence_len` in episode `idx`
        # '- offset' because "next states" are offset by 'offset'
        if np.random.rand() < self.terminal_prob:
            # Take sequence ending with terminal state
            start = t_start + episode_length - sequence_len - offset
        else:
            # NOTE: np.random.randint takes EXCLUSIVE upper bound of range to sample from
            start = t_start + np.random.randint(max(1, episode_length - sequence_len - offset))

        states_ = torch.from_numpy(self.dataset['states'][start:start + sequence_len + offset])
        actions_ = torch.from_numpy(self.dataset['actions'][start:start + sequence_len])

        states = torch.zeros(self.sequence_len, self.latent_dim, dtype=states_.dtype)
        next_states = torch.zeros(self.sequence_len, self.latent_dim, dtype=states_.dtype)
        actions = torch.zeros(self.sequence_len, self.action_dim, dtype=actions_.dtype)

        # Sample latent states (this is done to prevent overfitting of memory to a specific 'z'.)
        if self.is_deterministic:
            z_samples = states_[:, 0]
        else:
            mu = states_[:, 0]
            sigma = torch.exp(states_[:, 1] / 2)
            latent = Normal(loc=mu, scale=sigma)
            z_samples = latent.sample()

        states[:sequence_len] = z_samples[:-offset]
        next_states[:sequence_len] = z_samples[offset:]
        actions[:sequence_len] = actions_

        return [states, actions], [next_states]

    def __len__(self):
        return int(self.n_games * self.dataset_fraction)

I am not sure if this solves the concurrency issue, since all you do is defer the HDF5 file access from the ctor to the __getitem__ method, while ensuring with singleton-style code that the file is opened only the first time __getitem__ is called (for efficiency). Is that correct? However, the file handle that is used to access the data persists across calls…

I am puzzled

import h5py
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):

    def __init__(self, hdf5file):
    
        self.hdf5file = hdf5file
        self.dataset = None

        with h5py.File(self.hdf5file, "r") as dataset:
            self.NrFrms = dataset.attrs['NrFrms']
            self.NrChn = dataset.attrs['NrChn']

    def __len__(self):
        return self.NrFrms * self.NrChn

    def __getitem__(self, idx):

        if self.dataset is None:
            self.dataset = h5py.File(self.hdf5file, "r")

        access_idx = idx % self.NrFrms 
        access_chn = idx // self.NrFrms + 1 

        target = torch.tensor(1, dtype=torch.long) 

        data = torch.tensor(self.dataset['Chn'+str(access_chn)][access_idx], dtype=torch.float32)
        image = torch.tensor(self.dataset['Image'][access_idx], dtype=torch.float32)

        return image, data, target

still yields

RuntimeError: DataLoader worker (pid 17402) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

if more than one worker is used for the dataloading…

It's correct, I only defer HDF5 file opening to __getitem__ so it's opened by each worker (not serialised and sent to them). Where do you think it might cause concurrency issues? I only read from the file, I don't make any writes, and I assume the file isn't changed by any other process (though that could be handled too, see: Single Writer Multiple Reader (SWMR) — h5py 3.10.0 documentation).

Did you try rerunning with num_workers=0? Maybe it will indeed give you a better error. You should also look at what was printed above the message "RuntimeError: DataLoader worker (pid 17402) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. […]"; there should be a call stack for each process, where you can see what happened. Please copy-paste the full error and I'll give it a look.

I am happy to

Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 105, in spawn_main
exitcode = _main(fd)
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 114, in _main
prepare(preparation_data)
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 225, in prepare
_fixup_main_from_path(data['init_main_from_path'])
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 277, in _fixup_main_from_path
run_name="__mp_main__")
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/ditzel/…/racam/main_racam.py", line 3, in <module>
torch.multiprocessing.set_start_method('spawn')
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/context.py", line 242, in set_start_method
raise RuntimeError('context has already been set')
RuntimeError: context has already been set
Traceback (most recent call last):
File "/home/ditzel//main_racam.py", line 136, in <module>
train(epoch)
File "/home/ditzel/main_racam.py", line 48, in train
for batch_idx, (img, rdm, t) in enumerate(trainloader):
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
idx, batch = self._get_batch()
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 610, in _get_batch
return self.data_queue.get()
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/queues.py", line 94, in get
res = self._recv_bytes()
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
buf = self._recv(4)
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 274, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 10890) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

when I increase number of workers to 2. If set to 0, unfortunately there is no error and the program works as expected

Sorry for the delay. You need to protect torch.multiprocessing.set_start_method('spawn') in /home/ditzel/radar/radarcamerafusion/racam/main_racam.py:3 with an if __name__ == "__main__" guard. This function may only be called once, but currently it runs again in every spawned worker process when the main module is re-imported. See this issue for more details: RuntimeError: context has already been set(multiprocessing) · Issue #3492 · pytorch/pytorch · GitHub.

Also, you can try deleting that line entirely; if you have HDF5 1.10, it shouldn't be needed.
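For reference, a minimal sketch of how the guard could look in main_racam.py (the main() entry point and its contents are hypothetical):

import torch


def main():
    # build the Dataset / DataLoader (num_workers > 0) and run training here
    pass


if __name__ == "__main__":
    # Only the main process sets the start method; spawned workers re-import
    # this module, and without the guard they would call it again and fail
    # with "context has already been set".
    torch.multiprocessing.set_start_method('spawn')
    main()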

After installing HDF5 1.10 from source, how do I change the default HDF5 version on the system (I'm using Ubuntu 16.04) to 1.10? I ran h5cc -showconfig and it still says HDF5 1.8. Could you give me some instructions?

Sorry, I didn't do this myself, so I can't help much. You'll need to Google it :wink: I think you'll need to change LD_LIBRARY_PATH.
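A quick sanity check from Python (standard h5py attributes) shows which HDF5 version your h5py build is actually linked against:

import h5py

print(h5py.version.version)       # h5py version
print(h5py.version.hdf5_version)  # version of the underlying HDF5 C library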


Thank you for your help. Problem solved already. :smile:

Hi, do we need to set the SWMR state when opening the HDF5 file, like this:

with h5py.File(self.h5file_path, 'r', swmr=True) as file:
    self.keys = list(file['dataset'].keys())

file = h5py.File(self.h5file_path, 'r', swmr=True)