Working with a dataset that doesn't fit in GPU memory

Hi!

I’ve read the entire discussion on Loading huge data functionality but could not find a solution that fits my case. I’m training a model for NLU, and I have a huge JSONL file that must be used as input.

I’ve split the dataset into multiple files (using pickle to save some disk space), each one corresponding to a chunk of data, and created a custom Dataset class to load them. The class keeps a list of files with all data chunks and loads a given chunk into memory only when __getitem__ is called with an index that belongs to that chunk. The idea is to always keep only a single chunk of data in memory.
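For reference, the chunking step looks roughly like this (the file names, field layout, and chunk size below are placeholders; my real script differs in the details, but each chunk file ends up as a dict of column name to list of values):

import json
import pickle

CHUNK_SIZE = 50_000  # examples per chunk; placeholder value


def split_jsonl_into_chunks(jsonl_path, out_prefix):
    chunk, chunk_idx = [], 0
    with open(jsonl_path) as f:
        for line in f:
            chunk.append(json.loads(line))
            if len(chunk) == CHUNK_SIZE:
                _dump_chunk(chunk, out_prefix, chunk_idx)
                chunk, chunk_idx = [], chunk_idx + 1
    if chunk:
        _dump_chunk(chunk, out_prefix, chunk_idx)


def _dump_chunk(examples, out_prefix, chunk_idx):
    # Store the chunk as a dict of column -> list of values, which is what
    # the Dataset class below turns into tensors.
    columns = {key: [ex[key] for ex in examples] for key in examples[0]}
    with open(f"{out_prefix}_{chunk_idx}.pkl", "wb") as f:
        pickle.dump(columns, f)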

My custom Dataset class has a __getitem__ that does the following:

  1. Use the idx to determine which chunk file the example lives in.
  2. Clear the data chunk (tensors) currently loaded in GPU memory and load the chunk that contains the desired idx into GPU memory.
  3. Return the item at that index.

Unfortunately, that was not enough to solve the problem. I’m still seeing a CUDA error: out of memory message (even for small chunks), and I don’t understand why this is happening. Here is the relevant part of my Dataset class (bookkeeping helpers such as map_file_to_chunk_idx, get_device, and _get_data_chunk_idx_current_loaded_on_device are omitted), followed by the traceback:

import pickle

import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(
        self,
        data_files: list = None,
    ):
        # Builds self.chunk_idx_to_file, self.num_examples and
        # self.data_chunk_size (implementation omitted here)
        self.map_file_to_chunk_idx(data_files)

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        idx = self._get_idx_in_data_chunk(idx)

        item = {key: val[idx] for key, val in self.data_chunk.items()}
        return item

    def _get_idx_in_data_chunk(self, idx):
        desired_data_chunk_idx = int(idx / self.data_chunk_size)
        device = get_device()
        if (
            self._get_data_chunk_idx_current_loaded_on_device(device)
            != desired_data_chunk_idx
        ):
            self._load_data_chunk_into_device(
                chunk_idx=desired_data_chunk_idx, device=device
            )

        # Update idx to be the idx in the current chunk
        idx = idx % self.data_chunk_size

        return idx

    def _load_data_chunk_into_device(self, chunk_idx=0, device="cpu"):
        if self._get_data_chunk_idx_current_loaded_on_device(device) == chunk_idx:
            return

        # Drop the previously loaded chunk before loading the new one
        del self.data_chunk

        file_path = self.chunk_idx_to_file[chunk_idx]

        with open(file_path, "rb") as f:
            data_chunk = pickle.load(f)
        self.data_chunk = dict()
        for k, v in data_chunk.items():
            self.data_chunk[k] = torch.tensor(v, device=device)

  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 111, in rebuild_cuda_tensor
    storage = storage_cls._new_shared_cuda(
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Traceback (most recent call last):
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/multiprocessing/queues.py", line 107, in get
    if not self._poll(timeout):
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 11254) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/workspaceblobstore/azureml/f5dc6cae-9033-4ad4-955b-057973ec9fc7/dldistillery/run_distillation.py", line 286, in <module>
    main()
  File "/workspaceblobstore/azureml/f5dc6cae-9033-4ad4-955b-057973ec9fc7/dldistillery/run_distillation.py", line 275, in main
    trainer.fit(distillation_module)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 771, in fit
    self._call_and_handle_interrupt(
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 724, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1237, in _run
    results = self._run_stage()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1324, in _run_stage
    return self._run_train()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1346, in _run_train
    self._run_sanity_check()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1414, in _run_sanity_check
    val_loop.run()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 153, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 111, in advance
    batch = next(data_fetcher)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
    return self.fetching_function()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 259, in fetching_function
    self._fetch_next_batch(self.dataloader_iter)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 273, in _fetch_next_batch
    batch = next(iterator)
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
    idx, data = self._get_data()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
    success, data = self._try_get_data()
  File "/azureml-envs/azureml_4624ee7bd86de7366e1c13ba9aa23551/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 11254) exited unexpectedly

My impression is that GPU memory keeps increasing over time instead of only a single chunk staying resident. Could you help me understand what I’m missing here? What would be the best way to understand how memory is allocated on the GPU?

To separate out other potential issues: is the OOM reproducible if, for example, an empty tensor is loaded each time rather than data being loaded from disk via pickle.load? If it is, I would check whether any other variables are holding on to the returned data, or whether a leak is occurring somewhere in a loop.
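Something along these lines could serve as a temporary drop-in for your _load_data_chunk_into_device method during the test; it keeps the chunk-switching logic but skips the disk load, and prints how much CUDA memory is allocated after each switch (the "input_ids"/"labels" keys and the sequence length 128 are placeholders, so use the keys and shapes your real chunks have):

import torch

# Temporary version of MyDataset._load_data_chunk_into_device for debugging.
def _load_data_chunk_into_device(self, chunk_idx=0, device="cpu"):
    if self._get_data_chunk_idx_current_loaded_on_device(device) == chunk_idx:
        return

    del self.data_chunk

    # Allocate a fixed-size dummy chunk instead of unpickling from disk.
    self.data_chunk = {
        "input_ids": torch.zeros(
            self.data_chunk_size, 128, dtype=torch.long, device=device
        ),
        "labels": torch.zeros(self.data_chunk_size, dtype=torch.long, device=device),
    }

    if torch.cuda.is_available():
        print(
            f"after loading chunk {chunk_idx}: "
            f"allocated={torch.cuda.memory_allocated(device) / 2**20:.1f} MiB, "
            f"reserved={torch.cuda.memory_reserved(device) / 2**20:.1f} MiB"
        )

If the allocated number keeps growing across chunk switches even with the dummy data, the problem is not the pickle loading itself. torch.cuda.memory_summary() can also give a more detailed per-allocator breakdown if you need it.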