Thank you, your answers are very detailed and your contributions to this community are invaluable =)
I wonder how this generalizes to the more common case in which there are multiple datasets within the HDF5 file that need to be accessed by index. Is duplicating the code and loading each of the respective datasets into torch tensors the way to go here?
Here is an example of a non-trivial Dataset that I use for preprocessing data for World Models' memory module training:
import h5py
import numpy as np
import torch
from torch.distributions import Normal
from torch.utils.data import Dataset


class MemoryDataset(Dataset):
    """Dataset of sequential data to train memory.

    Args:
        dataset_path (string): Path to HDF5 dataset file.
        sequence_len (int): Desired output sequence length.
        terminal_prob (float): Probability of sampling a sequence that finishes with a
            terminal state.
        dataset_fraction (float): Fraction of dataset to use during training, value range: (0, 1]
            (dataset forepart is taken).
        is_deterministic (bool): If True, return mean latent states; otherwise return sampled
            latent states.

    Note:
        Arrays should have the same size of the first dimension and their type should be the
        same as the desired Tensor type.
    """

    def __init__(self, dataset_path, sequence_len, terminal_prob, dataset_fraction, is_deterministic):
        assert 0 < terminal_prob <= 1.0, "0 < terminal_prob <= 1.0"
        assert 0 < dataset_fraction <= 1.0, "0 < dataset_fraction <= 1.0"

        self.dataset = None
        self.dataset_path = dataset_path
        self.sequence_len = sequence_len
        self.terminal_prob = terminal_prob
        self.dataset_fraction = dataset_fraction
        self.is_deterministic = is_deterministic

        # Only read metadata here; the handle used for data access is created lazily in
        # __getitem__ so each DataLoader worker opens its own file, see:
        # https://stackoverflow.com/questions/46045512/h5py-hdf5-database-randomly-returning-nans-and-near-very-small-data-with-multi
        with h5py.File(self.dataset_path, "r") as dataset:
            self.latent_dim = dataset.attrs["LATENT_DIM"]
            self.action_dim = dataset.attrs["ACTION_DIM"]
            self.n_games = dataset.attrs["N_GAMES"]

    def __getitem__(self, idx):
        """Get a sequence of given length starting at a random position in episode `idx`."""
        offset = 1

        # Lazily open the HDF5 file, once per worker process.
        if self.dataset is None:
            self.dataset = h5py.File(self.dataset_path, "r")

        t_start, t_end = self.dataset['episodes'][idx:idx + 2]
        episode_length = t_end - t_start
        if self.sequence_len <= episode_length - offset:
            sequence_len = self.sequence_len
        else:
            sequence_len = episode_length - offset
            # log.info(
            #     "Episode %d is too short to form full sequence, data will be zero-padded.", idx)

        # Sample where to start the sequence of length `self.sequence_len` in episode `idx`.
        # '- offset' because "next states" are offset by 'offset'.
        if np.random.rand() < self.terminal_prob:
            # Take a sequence ending with a terminal state.
            start = t_start + episode_length - sequence_len - offset
        else:
            # NOTE: np.random.randint takes the EXCLUSIVE upper bound of the range to sample from.
            start = t_start + np.random.randint(max(1, episode_length - sequence_len - offset))

        states_ = torch.from_numpy(self.dataset['states'][start:start + sequence_len + offset])
        actions_ = torch.from_numpy(self.dataset['actions'][start:start + sequence_len])

        states = torch.zeros(self.sequence_len, self.latent_dim, dtype=states_.dtype)
        next_states = torch.zeros(self.sequence_len, self.latent_dim, dtype=states_.dtype)
        actions = torch.zeros(self.sequence_len, self.action_dim, dtype=actions_.dtype)

        # Sample latent states (this is done to prevent overfitting of memory to a specific 'z').
        if self.is_deterministic:
            z_samples = states_[:, 0]
        else:
            mu = states_[:, 0]
            sigma = torch.exp(states_[:, 1] / 2)
            latent = Normal(loc=mu, scale=sigma)
            z_samples = latent.sample()

        states[:sequence_len] = z_samples[:-offset]
        next_states[:sequence_len] = z_samples[offset:]
        actions[:sequence_len] = actions_

        return [states, actions], [next_states]

    def __len__(self):
        return int(self.n_games * self.dataset_fraction)
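A minimal usage sketch for context (the file name and hyperparameter values below are placeholders, not the ones from my experiments):

from torch.utils.data import DataLoader

# Placeholder hyperparameters, just to show the intended usage.
dataset = MemoryDataset(
    dataset_path="experience.h5",
    sequence_len=32,
    terminal_prob=0.5,
    dataset_fraction=1.0,
    is_deterministic=False,
)

# Each worker opens its own HDF5 handle lazily in __getitem__,
# so num_workers > 0 is safe here.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

for (states, actions), (next_states,) in loader:
    ...  # feed the sequences to the memory module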
I am not sure whether this solves the concurrency issue, since all you do is defer the HDF5 file access from the constructor to the __getitem__ method, while ensuring with a singleton-style check that the file is opened only the first time __getitem__ is called (for efficiency). Is that correct? However, the file handle that is used to access the data in each call persists…
RuntimeError: DataLoader worker (pid 17402) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.
if more than one worker is used for the data loading…
It's correct, I only defer HDF5 file opening to __getitem__ so it's opened by each worker (not serialised and sent to them). Where do you think it might cause concurrency issues? I only read from the file, I do not make any writes, and I assume the file isn't changed by any other process (but that could be done too, see: Single Writer Multiple Reader (SWMR) — h5py 3.10.0 documentation).
Did you try rerunning with num_workers=0? Maybe it will indeed give you a better error. You should also look at what was printed above the message "RuntimeError: DataLoader worker (pid 17402) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. […]"; there should be a call stack for each process, where you can see what happened. Please copy-paste the full error and I'll give it a look.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 114, in _main
    prepare(preparation_data)
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 225, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/spawn.py", line 277, in _fixup_main_from_path
    run_name="__mp_main__")
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ditzel/…/racam/main_racam.py", line 3, in <module>
    torch.multiprocessing.set_start_method('spawn')
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/context.py", line 242, in set_start_method
    raise RuntimeError('context has already been set')
RuntimeError: context has already been set
Traceback (most recent call last):
  File "/home/ditzel//main_racam.py", line 136, in <module>
    train(epoch)
  File "/home/ditzel/main_racam.py", line 48, in train
    for batch_idx, (img, rdm, t) in enumerate(trainloader):
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    idx, batch = self._get_batch()
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 610, in _get_batch
    return self.data_queue.get()
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/queues.py", line 94, in get
    res = self._recv_bytes()
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/home/ditzel/anaconda3/envs/py37/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 274, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 10890) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.
when I increase the number of workers to 2. If I set it to 0, unfortunately there is no error and the program works as expected.
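Judging from the first traceback, each spawned worker re-imports main_racam.py and runs the unguarded torch.multiprocessing.set_start_method('spawn') call at module level again, which raises "context has already been set" and kills the worker. A minimal sketch of the usual guard (the function and module layout here are placeholders, not the actual project code):

import torch.multiprocessing as mp


def main():
    # Build the Dataset/DataLoader and run the training loop here, so nothing
    # heavy executes when a spawned worker re-imports this module.
    ...


if __name__ == "__main__":
    # Only the parent process reaches this block; the spawn machinery that
    # re-imports the module in the workers does not re-execute it.
    mp.set_start_method("spawn", force=True)
    main()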
After installing HDF5 1.10 from source, how do I change the default HDF5 version on the system (I'm using Ubuntu 16.04) to 1.10? I ran h5cc -showconfig and it still says HDF5 1.8. Could you give me some instructions?
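One thing worth checking (just a suggestion, not from the original replies): h5cc -showconfig reports the system toolchain, while what matters for the Dataset code in this thread is which HDF5 your h5py build is linked against. A quick way to inspect that from Python:

import h5py

# HDF5 version this h5py installation was built against.
print(h5py.version.hdf5_version)

# Full build/version report (h5py, HDF5, NumPy, Python).
print(h5py.version.info)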
If you don't plan to write to this HDF5 file while it's being read, then no. If you need SWMR (Single-Writer/Multiple-Reader), then you should enable it. If the latter is your use case, please read more about it, e.g. here.
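For completeness, a minimal SWMR sketch along the lines of the h5py documentation (file and dataset names are made up for the example); the writer has to opt in before readers open the file with swmr=True:

import h5py
import numpy as np

# Writer side: create the file with the latest file format and switch on SWMR mode
# after all datasets exist.
with h5py.File("experience.h5", "w", libver="latest") as f:
    dset = f.create_dataset("states", shape=(0, 32), maxshape=(None, 32), dtype="f4")
    f.swmr_mode = True
    # From here on, readers may open the file while the writer keeps appending.
    dset.resize((128, 32))
    dset[:] = np.random.randn(128, 32).astype("f4")
    dset.flush()

# Reader side (e.g. inside __getitem__): open with swmr=True and refresh before reads.
with h5py.File("experience.h5", "r", swmr=True) as f:
    dset = f["states"]
    dset.refresh()  # pick up data flushed by the writer
    batch = dset[:16]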
Thank you so much for your contribution. I was wondering if you have compared performance between having the dataset as a simple folder containing flat files vs the dataset in HDF5?
I did not, but in the simple case where you have data stored locally on the machine you use for computation, it shouldn't make much difference. In that case my recommendation is: do whatever is easier for you, AND THEN, if you see that the DataLoader is a bottleneck and your GPU isn't fully utilised, try a binary format like HDF5 to store the data. My motivation was to store all the experience of an agent from an Atari game in one file, but now I'm working with an approach where each episode is saved as a NumPy archive in a directory and the DataLoader loads them one by one. It works just fine (and is MUCH easier to work with, as each episode can have a different length).
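A minimal sketch of what that per-episode layout could look like (the directory layout and array names are assumptions for illustration, not the exact code behind the approach described above):

import glob

import numpy as np
import torch
from torch.utils.data import Dataset


class EpisodeDataset(Dataset):
    """One .npz archive per episode; each archive holds 'states' and 'actions' arrays."""

    def __init__(self, episodes_dir):
        # Sorted so the mapping from index to episode is deterministic.
        self.paths = sorted(glob.glob(f"{episodes_dir}/*.npz"))

    def __getitem__(self, idx):
        episode = np.load(self.paths[idx])
        states = torch.from_numpy(episode["states"])
        actions = torch.from_numpy(episode["actions"])
        # Episodes may have different lengths, so batching whole episodes needs a
        # custom collate_fn (e.g. padding or packing).
        return states, actions

    def __len__(self):
        return len(self.paths)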
Hi @piojanu, @ptrblck,
First of all - I would like to thank you for the great research and code. It has increased my training speed by an order of magnitude!
I have used the HDF5 Dataset with a single hdf5 file and it worked great.
However, when I used it together with torch.utils.data.dataset.ConcatDataset, it slowed down significantly. Do you have any idea why this happens?
If it helps, here is the time profiling of my training. Acquiring threads takes up most of the time. I have very little knowledge of threading, so this doesn't make much sense to me. I have 6 workers, so maybe the 87% figure has to do with 5/6 ≈ 83%? I'm just guessing here.
After looking this up, I saw a post that suggested using pin_memory=False. This does not speed anything up, but rather moves the bottleneck elsewhere:
Hi, I followed the advice in this thread but am still stuck with this error:
File "h5py/_objects.pyx", line 184, in h5py._objects.ObjectID.__cinit__
TypeError: __cinit__() takes exactly 1 positional argument (0 given)
I moved the file opening into the __getitem__ function of the dataset and I set the start method to spawn, but neither of these fixes the problem. I also tried using h5pickle per this issue, but the error remains. I am on Arch Linux with the latest version of h5py (1.12). Here is a snippet of my Dataset:
    def __init__(self, file, group, vel_seq, normalization=True):
        self.dataset = None
        self.len = None
        self.file = file
        self.group = group
        self.vel_seq = vel_seq
        self.normalization = normalization

    def __getitem__(self, _):
        # Open the HDF5 file lazily, once per worker.
        if self.dataset is None:
            self.dataset = h5py.File(self.file, 'r')
        # do stuff with dataset
Also, I've tried adding the swmr=True flag in the File call, but that doesn't do anything either. Has anyone had this problem?
My torch version is 1.9 with CUDA 11.1 on an RTX 3090 GPU. When I use a number of workers > 0, I get these warnings, and their count equals the number of workers; here there are four because I set num_workers=4:
warning: leaking caffe2 thread-pool after fork. (function pthreadpool)
warning: leaking caffe2 thread-pool after fork. (function pthreadpool)
warning: leaking caffe2 thread-pool after fork. (function pthreadpool)
warning: leaking caffe2 thread-pool after fork. (function pthreadpool)
When the code saves the trained model, it produces this error and doesn't write the model correctly.
Thanks for your reply
The code was working fine with torch version 1.1 and CUDA 9.
When I use torch 1.9 with CUDA 11.1, I get this error.
The model is not saved correctly, and the accuracy also became worse when I set pin_memory=False. The accuracy becomes better when I set the number of workers to 0 and the warnings disappear, but the training becomes too slow.
In summary, I don't know how to fix this issue without affecting the accuracy, the speed, and the model saving.