Sudden memory spike during distributed training

Hi,

I have been experiencing memory spikes during a distributed training run and have been scouring the Internet for a fix. Specifically, memory consumption stays constant for a while after training starts (how long depends on num_workers: a matter of hours with 32, a couple of days with 16, and around one day with 8, for each of the 8 processes driving our 8 GPUs on a machine with 128 CPU cores / 256 threads), then spikes extremely fast, consumes all of the 1 TB of RAM on the machine, and the OOM reaper kills the process.

We have made sure that the dataset does not hold any Python lists and only uses np.string_ for strings per this issue. We use the default file_descriptor sharing strategy, set ulimit to unlimited, call resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])), and use torch.multiprocessing.set_start_method('spawn'). We also call gc.collect() every 1000 steps, but none of this seems to alleviate the issue.
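For reference, the setup described above boils down to roughly the following (a sketch with illustrative names, not our exact code):

import gc
import resource

import torch.multiprocessing as mp


def configure_workers(nofile_soft: int = 4096) -> None:
    # Raise the soft open-file limit; file_descriptor sharing needs many FDs.
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (nofile_soft, rlimit[1]))
    # file_descriptor is already the default sharing strategy on Linux,
    # but we set it explicitly, and start workers with 'spawn'.
    mp.set_sharing_strategy('file_descriptor')
    mp.set_start_method('spawn')


def maybe_collect(step: int, every: int = 1000) -> None:
    # Called from the training loop to force a GC pass every `every` steps.
    if step > 0 and step % every == 0:
        gc.collect()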

We tried hitting Ctrl+C while the memory was spiking, and the stack traces it printed looked like a DataLoader problem, but we were unable to capture the output. Can anyone help us with this?
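One idea we have not tried yet is Python's built-in faulthandler, which can dump all thread stacks to a file on a signal, roughly like this (untested sketch; the log path is a placeholder):

import faulthandler
import os
import signal

# Placeholder path; one file per process so dumps don't interleave.
trace_file = open(f"/tmp/stacks_{os.getpid()}.log", "w")

# Dump all thread stacks whenever the process receives SIGUSR1,
# e.g. `kill -USR1 <pid>` while memory is climbing.
faulthandler.register(signal.SIGUSR1, file=trace_file, all_threads=True)

This would not catch the SIGKILL from the OOM reaper itself, but it would let us grab the stacks on demand while memory is still climbing instead of relying on terminal output.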

I should probably mention that our DataLoader is configured as:

DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    sampler=sampler,
    persistent_workers=num_workers > 0,
    collate_fn=default_collate,
    pin_memory=True,
    prefetch_factor=2,
    drop_last=True,
)

But none of the changes above seem to help.

Also, the number of processes stays constant during the spike (the same as when memory consumption is stable).

Quite an interesting problem. Are you seeing any memory increase before the sudden spike occurs?
Based on your debugging so far, your suspects seem to be the Dataset or DataLoader. Could you replace the Dataset with a TensorDataset containing static tensors? The model wouldn't train, but it would be interesting to see whether the same behavior appears. If the memory doesn't increase, could you share the Dataset implementation?
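Something like this would work as a quick check (shapes and dtypes are placeholders):

import torch
from torch.utils.data import TensorDataset

# Static random tensors reused for every epoch; adjust shapes/dtypes as needed.
num_samples, seq_len = 10_000, 128
inputs = torch.randint(0, 20, (num_samples, seq_len))
targets = torch.zeros(num_samples, dtype=torch.long)

dataset = TensorDataset(inputs, targets)
# Plug this `dataset` into the exact same DataLoader setup as before.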

Hi @ptrblck, thanks! We have not seen any memory-leak-like growth before the spike, and we have been focusing on the Dataset and DataLoader because of the variable time the job runs before it fails. I can provide the implementation below, but it may take a while before I see any results from the "TensorDataset with static tensors" experiment.

import os
import random
from typing import Iterable

import lmdb
import numpy as np
from torch.utils.data import Dataset

# `vocab` and `list_index_select` are defined elsewhere in our codebase.


class LargeSequenceSet(Dataset):

    def __init__(
            self, processed_fn: str, transform=None, indices: Iterable = None
    ):
        super().__init__()
        self.processed_file = os.path.expanduser(processed_fn)
        self.transform = transform
        self.db = None
        self.keys = None
        # np.int is deprecated; also guard against indices being None
        self.indices = (
            np.asarray(indices, dtype=np.int64) if indices is not None else None
        )
        self._len = None

    def _connect_db(self):
        """
            Establish read-only database connection
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_file,
            max_readers=4096,
            map_size=107374182400,
            readonly=True
        )
        with self.db.begin() as txn:
            keys = list(txn.cursor().iternext(values=False))
        if self.indices is not None:
            keys = list_index_select(keys, self.indices)
        self.keys = np.asarray(keys, dtype=np.string_)
        self._len = len(self.keys)

    def _close_db(self):
        self.db.close()
        self.db = None
        self.keys = None

    def __len__(self):
        if self._len is None:
            self._connect_db()
            self._close_db()
        return self._len

    def __getitem__(self, index):
        if self.db is None:
            self._connect_db()
        key = self.keys[index]
        data = self.db.begin().get(key).decode("ascii")
        # Resample a random entry until every character is in the vocabulary.
        while any(a not in vocab for a in data):
            index = random.randint(0, len(self.keys) - 1)
            key = self.keys[index]
            data = self.db.begin().get(key).decode("ascii")

        if self.transform is not None:
            data = self.transform(data)
        return data

Hi @ptrblck, I would like to share some recent discoveries we have made on this issue.

We tried to follow gdb - Python script terminated by SIGKILL rather than throwing MemoryError - Stack Overflow to make the process report a Python stack trace. However, the memory spike went away. We tried several freshly installed machines and saw the same thing: there is no spike when the OOM reaper is disabled, and spikes happen when the OOM reaper is left at its default settings.

We are not yet sure how to interpret this result, but we would like to share it.

Thanks for sharing this update. I haven't seen such behavior before, but maybe @cbalioglu has an idea of what might be causing this issue.

cc @VitalyFedyunin for data loader questions

@Rui_Wang I know it’s been a long time since you encountered this problem, but do you remember how you disabled the OOM Reaper?

Did you just set the following in /etc/sysctl.conf?

vm.overcommit_memory = 2
vm.overcommit_ratio = 100

Any ideas how to do this when running training inside Docker containers? We’re facing the exact same issue, where memory spikes randomly several hours into training.

Hi, I wonder if you have figured out how to solve this problem. I am facing the same issue: the memory increases suddenly after a period of stable training (20-30 hours). Did you solve the problem by modifying sysctl.conf?

Yeah, I followed the same config in the link. I don’t use Docker at this point, so I really have no idea, sorry.

Yeah just follow the link above.