Dataset not being copied across worker processes

Hi, and apologies if my question has already been answered somewhere; I couldn't really find anything because questions in the opposite direction always show up. Apologies also that I cannot share my actual code (it's probably proprietary).

My loop looks something like this.

dataset = Dataset_Class(args)
cache_ds = deepcopy(dataset)
dataloader = DataLoader(dataset, **other_args)

for e in range(epochs):
    for step, batch in enumerate(dataloader):
        cache_ds.cache_batch(batch)  # store the freshly loaded batch in the cache dataset
        optimizer.zero_grad()
        # amp block
        loss = model(batch)
        loss.backward()
        optimizer.step()
    # the same thing is done for validation
    if e == 0:
        # from epoch 1 onwards, iterate over the fully cached dataset instead
        dataloader = DataLoader(cache_ds, **other_args)

Now, naively, I would think that this caches all the data inside cache_ds, which lives in the main process' RAM, during the first epoch (which it does). I would also assume that caching the data live in this manner (and later reading from the cache) is faster than going through disk (it is) or than the shared-dict suggestion from pytorch_misc/shared_dict.py at master · ptrblck/pytorch_misc · GitHub, since the DataLoader worker processes would not incur any IPC overhead from accessing shared memory. After one large copy at the end of epoch 0, every worker should hold its own fully cached dataset in its own RAM, and loading from it should be quite fast.
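For comparison, the kind of shared-dict setup I mean would look roughly like this (a rough sketch using a multiprocessing Manager proxy, not necessarily identical to the linked example):

import torch
from multiprocessing import Manager

class SharedCacheDataset(torch.utils.data.Dataset):
    def __init__(self, size, shared_cache):
        self.size = size
        self.cache = shared_cache  # Manager proxy; every access goes through IPC

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if idx not in self.cache:
            # "load" the sample and push it into the shared cache (one IPC round trip)
            self.cache[idx] = torch.ones(3, 256, 256, dtype=torch.float16)
        return self.cache[idx], idx

# manager = Manager()
# ds = SharedCacheDataset(int(1.4e6), manager.dict())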

That last set of assumptions is apparently where I am wrong, but I don't understand why. According to free -mh running in a separate shell on the machine I work on, RAM usage does not increase after the first epoch, regardless of whether I set the number of workers to 2 or 4 (the latter should run me out of memory, for the record).

Why does this not happen? Shouldn’t the dataset be copied to every worker unless I explicitly design the dataset around a shared dict behind a proxy object like in the GitHub link?

Is there maybe some hardcoded limit up to which the DataLoader will copy your dataset, past which it decides not to copy it? The in-memory size of the dataset is roughly 500 GB and the machine has 2 TB of RAM, so the parent's copy plus a full copy in each of 4 workers should not fit.
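In case it is useful, this is the kind of per-process check I could also run to complement free -mh (a rough sketch, assuming psutil is available; log_worker_rss is just a hypothetical helper passed as worker_init_fn):

import os
import psutil  # assumed to be available

def log_worker_rss(worker_id):
    # worker_init_fn runs once inside each freshly started worker process
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1e9
    print(f"worker {worker_id}, pid {os.getpid()}, rss {rss_gb:.1f} GB")

# loader = torch.utils.data.DataLoader(cache_ds, batch_size=4096, num_workers=4,
#                                      worker_init_fn=log_worker_rss)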

Don’t know if any of this is relevant:
I checked whether I was even running more than one process, and I am certain that I am.
Also, because I think I read somewhere that tensors have a sort of shared memory implementation by default, I saved all tensors in the cache as numpy arrays instead of tensors (via tensor.clone().detach().numpy() and then splitting the numpy array along the batch dimension).
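For reference, this is the shared-memory machinery I was trying to stay away from, next to the conversion I do instead (a minimal sketch with dummy shapes):

import torch

t = torch.ones(3, 256, 256, dtype=torch.float16)
t.share_memory_()                 # explicitly moves the tensor's storage into shared memory
print(t.is_shared())              # True
arr = t.clone().detach().numpy()  # what I cache instead: a plain, per-process numpy copy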

Based on the docs, I would assume the memory usage should increase if you are pre-loading a lot of ref-counted objects in the __init__ method:

After several iterations, the loader worker processes will consume the same amount of CPU memory as the parent process for all Python objects in the parent process which are accessed from the worker processes. This can be problematic if the Dataset contains a lot of data (e.g., you are loading a very large list of filenames at Dataset construction time) and/or you are using a lot of workers (overall memory usage is number of workers * size of parent process).
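(The scenario the docs describe would be something like the following sketch — an assumed example, not the code from this thread — where a large Python list of per-sample entries is built in __init__ and is then read from every worker:)

import torch

class FilelistDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples):
        # many small, ref-counted Python objects created in the parent process
        self.filenames = [f"/data/sample_{i}.png" for i in range(num_samples)]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        # merely reading a list entry from a forked worker touches refcounts,
        # which gradually copies the parent's pages into the worker
        path = self.filenames[idx]
        return torch.zeros(1), idx  # placeholder instead of actually loading `path`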

However, I assume cache_ds now contains all the samples and is no longer loading them lazily. Is this assumption correct?

Yes, that is correct. The dataset knows which of its subsets are cached or uncached, and once a flag is set (which happens at the end of epoch 0, although I did not explicitly mention it in the original post) it no longer loads lazily and instead reads from the cache (a dict containing tuples of numpy arrays, in case it matters).

OK, thanks for the explanation. Would it be possible to upload a small dataset somewhere and share a minimal, executable code snippet to reproduce the issue?

Hopefully yes :smile: I will try to put together an MVE and get back to you ASAP.

This is an extremely reduced version of the code I ran. Going into epochs > 0 leaves the memory footprint unchanged, and there appears to be only one dataset object despite there being two worker processes. Unless something changed in a recent release, the same holds true for any number of workers.

import torch
import numpy as np
from copy import deepcopy
import time
from tqdm.auto import tqdm

def itemize_and_cache(data, idx, targets, cache_ds):
    # Convert the collated batch back into per-sample numpy arrays and store
    # them in the cache dataset under their original indices.
    cd = data.clone().detach().numpy()
    ct = targets.clone().detach().numpy()
    cd = np.squeeze(np.split(cd, len(cd), axis=0), axis=1)
    ct = np.squeeze(np.split(ct, len(ct), axis=0), axis=1)
    ci = idx.clone().detach().tolist()
    cache_ds._batch_to_cache(cd, ct, ci)

class CacheError(Exception):
    pass

class Dataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.cache = {}
        self.caching_complete = False
        self.size = size

    def __len__(self):
        if self.caching_complete is False:
            return self.size
        else:
            return len(self.cache)

    def _batch_to_cache(self, cd, ct, ci):
        for i, idx in enumerate(ci):
            if cd[i] is None or ct[i] is None:
                item = None
            else:
                item = (cd[i], ct[i])
            self.cache[idx] = item

    def _from_cache(self, idx):
        try:
            item = self.cache[idx]
            return (torch.as_tensor(item[0]), torch.as_tensor(item[1]))
        except (KeyError, TypeError): # does not exist OR exists but is empty
            raise CacheError(f"Item {idx} not in cache!")

    def __getitem__(self, idx, layer=0):
        if layer == 5:
            raise RuntimeError("Dataset probably has a broken component, stopping.")
        try:
            if self.caching_complete is True:
                data, target = self._from_cache(idx)
                return data, idx, target
            else:
                # Read from disk here, dummy data
                data = torch.ones(size=[3, 256, 256], dtype=torch.float16)
                target = torch.LongTensor(0)
                return data, idx, target
        except KeyboardInterrupt:
            raise
        except Exception as e:
            print(repr(e))
            # retry with a random integer index (capped at 5 levels via `layer`)
            return self.__getitem__(idx=int(np.random.randint(0, len(self))), layer=layer + 1)

batch_size = 4096

dataset = Dataset(int(1.4e6)) # Approximate size of ImageNet-1k
cache_ds = deepcopy(dataset)

dataloader = torch.utils.data.DataLoader(
    dataset = dataset,
    batch_size = batch_size,
    num_workers = 32,
    prefetch_factor = 1,
    shuffle = True,
    drop_last = False,
    persistent_workers = False)

for epoch in range(10):
    if epoch == 0:
        steps = len(dataloader.dataset) // batch_size + 1
    else:
        steps = len(dataloader.dataset) // batch_size

    for step, (data, idxs, targets) in tqdm(enumerate(dataloader), total=steps):
        if epoch == 0:
            itemize_and_cache(data, idxs, targets, cache_ds)
        else:
            # model calcs placeholder
            time.sleep(0.1)
            
    if epoch == 0:
        cache_ds.caching_complete = True
        dataloader = torch.utils.data.DataLoader(
            dataset = cache_ds,
            batch_size = batch_size,
            num_workers = 2,
            prefetch_factor = 2,
            shuffle = True,
            drop_last = True,
            persistent_workers = True
        )

Chances are I am missing something really obvious, but at first glance it seems like there are two things happening: 1) there is no copy, and 2) the workers can access some manifestation of the cache dictionary from their own processes.

I can confirm that there really are multiple worker processes: by printing from Python, via shell commands, and by watching the tqdm progress bar, which makes two quick jumps for the time.sleep “calculation” and then waits on the mock data loading.

Presumably not relevant, but the whole thing runs on PyTorch 1.11.0 on a DGX.

Throwing a print(hex(id(self.cache))) into the __getitem__ method of the Dataset class, and setting a flag so that each dataset instance prints only once, showed me the same id four times. So they do apparently see the same cache object, right?

(Sidenote: Why four? 2 worker processes x 2 threads?)
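To check whether an identical id actually implies a single shared object (rather than a forked copy sitting at the same virtual address), something along these lines should work outside of PyTorch, assuming Linux fork semantics and CPython reporting the object address via id():

import multiprocessing as mp

cache = {"probe": 0}

def child():
    print("child id :", hex(id(cache)))
    cache["probe"] = 1  # mutate whatever dict the child is holding

if __name__ == "__main__":
    p = mp.get_context("fork").Process(target=child)
    p.start()
    p.join()
    print("parent id:", hex(id(cache)))
    print("parent sees probe =", cache["probe"])  # stays 0 if the dicts are separate copies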