Sudden memory spike during distributed training

Hi,

I have been experiencing memory spikes during a distributed training stage and have been scouring the Internet for a fix. Specifically, memory consumption stays constant for a while after training starts (the duration varies with num_workers: a matter of hours at 32, a couple of days at 16, and around one day at 8, for each of the 8 processes driving our 8 GPUs on a machine with 128 CPU cores / 256 threads), then spikes extremely fast, consumes all 1 TB of RAM on the machine, and the OOM reaper kills the process.

We have made sure that we do not use any Python lists in the dataset and only use np.string_ for strings per this issue. We use the default file_descriptor sharing strategy, set ulimit to unlimited, and call resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) as well as torch.multiprocessing.set_start_method('spawn'). We also call gc.collect() every 1000 steps, but none of this seems to alleviate the issue.
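For reference, here is a condensed sketch of the process-level settings described above. The values mirror our configuration; `total_steps` and the training body are placeholders.

```python
import gc
import resource

import torch.multiprocessing as mp

# Raise the soft open-file limit; the hard limit stays as-is.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, hard))

mp.set_start_method('spawn', force=True)    # must happen before workers spawn
mp.set_sharing_strategy('file_descriptor')  # the Linux default, set explicitly

GC_INTERVAL = 1000
total_steps = 3  # placeholder
for step in range(total_steps):
    # ... forward/backward/optimizer step would go here ...
    if step % GC_INTERVAL == 0:
        gc.collect()
```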

We tried Ctrl+C while the memory was spiking, and it printed stack traces that pointed at the DataLoader, but we were unable to capture the output. Can anyone help us with this?

I should probably mention that our DataLoader is set up as:

DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    sampler=sampler,
    persistent_workers=num_workers > 0,
    collate_fn=default_collate,
    pin_memory=True,
    prefetch_factor=2,
    drop_last=True,
)

But none of the changes above seem to help.

Also, the number of processes stays constant during the spike (the same as when memory consumption is flat).

Quite an interesting problem. Are you seeing any memory increase before the sudden spike occurs?
Based on your debugging so far, the suspects seem to be the Dataset or the DataLoader. Could you replace the Dataset with a TensorDataset containing static tensors? The model wouldn't train, but it would be interesting to see whether the same behavior occurs. If the memory doesn't increase, could you share the Dataset implementation?
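A minimal sketch of this experiment: swap the real Dataset for a TensorDataset of static tensors while keeping the DataLoader settings unchanged. The sizes below are placeholders, not the real data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

N, seq_len = 10_000, 128  # placeholder sizes
dummy = TensorDataset(torch.zeros(N, seq_len, dtype=torch.long))

loader = DataLoader(
    dummy,
    batch_size=32,
    num_workers=8,           # same worker count as the failing run
    persistent_workers=True,
    pin_memory=True,
    drop_last=True,
)
# If RAM still spikes while iterating this loader, the real Dataset is
# off the hook and the DataLoader machinery itself becomes the suspect.
```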

Hi @ptrblck, thanks! We have not seen any leak-like memory growth before the spike, and we have been focusing on the Dataset and DataLoader because of the variable time the job runs before failing. I can provide the implementation below, but it may take a while before I see any results from the "TensorDataset with static tensors" experiment.

import os
import random
from typing import Iterable, Optional

import lmdb
import numpy as np
from torch.utils.data import Dataset

# `vocab` and `list_index_select` are defined elsewhere in our codebase.


class LargeSequenceSet(Dataset):

    def __init__(
            self, processed_fn: str, transform=None,
            indices: Optional[Iterable] = None
    ):
        super().__init__()
        self.processed_file = os.path.expanduser(processed_fn)
        self.transform = transform
        self.db = None
        self.keys = None
        # np.int was removed in NumPy 1.24; also guard against indices=None,
        # which np.asarray would otherwise turn into a 0-d object array.
        self.indices = (
            np.asarray(indices, dtype=np.int64) if indices is not None else None
        )
        self._len = None

    def _connect_db(self):
        """Establish a read-only database connection."""
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_file,
            max_readers=4096,
            map_size=107374182400,  # 100 GiB
            readonly=True
        )
        with self.db.begin() as txn:
            keys = list(txn.cursor().iternext(values=False))
        if self.indices is not None:
            keys = list_index_select(keys, self.indices)
        self.keys = np.asarray(keys, dtype=np.string_)
        self._len = len(self.keys)

    def _close_db(self):
        self.db.close()
        self.db = None
        self.keys = None

    def __len__(self):
        if self._len is None:
            self._connect_db()
            self._close_db()
        return self._len

    def __getitem__(self, index):
        if self.db is None:
            self._connect_db()
        key = self.keys[index]
        data = self.db.begin().get(key).decode("ascii")
        # Resample until every character is in the vocabulary.
        while any(a not in vocab for a in data):
            index = random.randint(0, len(self.keys) - 1)
            key = self.keys[index]
            data = self.db.begin().get(key).decode("ascii")

        if self.transform is not None:
            data = self.transform(data)
        return data
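Separately from the Dataset, here is a small stdlib-only helper we could call from the training loop to timestamp when the spike starts. It reads the peak RSS via getrusage (ru_maxrss is reported in KiB on Linux, bytes on macOS, so the conversion below assumes Linux and is approximate); the function name is ours, not from any library.

```python
import resource
import time


def log_peak_rss(step: int) -> float:
    """Print and return this process's peak resident set size, in GB."""
    rss_gb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1024 / 1e9
    print(f"{time.strftime('%H:%M:%S')} step {step}: peak rss ~{rss_gb:.2f} GB")
    return rss_gb
```

Calling this at the same cadence as our gc.collect() would at least bracket the moment the spike begins.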

Hi @ptrblck, I would like to share some recent discoveries we have made on this issue.

We tried to follow gdb - Python script terminated by SIGKILL rather than throwing MemoryError - Stack Overflow to make the process report a Python stack trace. However, the memory spike went away. We repeated this on several freshly installed machines, and the same thing happened: there is no spike when the OOM reaper is disabled, and the spike returns when the OOM reaper is left at its default settings.

We are not yet sure how to interpret this result, but we would like to share it.

Thanks for sharing this update. I haven't seen such behavior before, but maybe @cbalioglu has an idea about what might be causing this issue.

cc @VitalyFedyunin for data loader questions