I have been experiencing memory spikes during a distributed training run and have been scouring the Internet for a fix. Specifically, memory consumption stays constant for a while after training starts (how long depends on num_workers: a matter of hours at 32, a couple of days at 16, and around one day at 8, for each of the 8 processes driving our 8 GPUs on a machine with 128 CPU cores / 256 threads), then spikes extremely fast, consuming all 1 TB of RAM on the machine, at which point the oom-reaper kills the process.
We have made sure we do not keep any Python lists in the dataset and store strings as np.string_ (per this issue); we use the default file_descriptor sharing strategy, set ulimit to unlimited, call resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])), and use torch.multiprocessing.set_start_method('spawn'). We also call gc.collect() every 1000 steps, but none of this seems to alleviate the issue.
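For reference, the stdlib parts of the mitigations above can be sketched as follows (the 4096 value mirrors the one in the post; the torch calls are quoted as comments since they are described but not shown here):

```python
import gc
import resource

# Raise the soft open-file limit: the file_descriptor sharing strategy
# keeps one fd per tensor shared with DataLoader workers, so the default
# soft limit (often 1024) can be exhausted.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
new_soft = 4096 if hard == resource.RLIM_INFINITY else min(4096, hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))

# torch-side settings described in the post:
# torch.multiprocessing.set_sharing_strategy('file_descriptor')
# torch.multiprocessing.set_start_method('spawn')

gc.collect()  # called every 1000 steps inside the training loop
```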
We hit Ctrl+C while the memory was spiking and the printed stacks looked like a DataLoader problem, but we were unable to capture the output. Can anyone help us with this?
Quite an interesting problem. Are you seeing any memory increase before the sudden spike occurs?
Based on your debugging so far, the suspects seem to be the Dataset or DataLoader. Could you replace the Dataset with a TensorDataset containing static tensors? The model wouldn't train, but it would be interesting to see whether the same behavior occurs. In case the memory doesn't increase, could you share the Dataset implementation?
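A minimal sketch of that swap (shapes and sizes here are placeholders, not taken from the original setup):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Static tensors standing in for real samples; shapes are arbitrary.
inputs = torch.zeros(1024, 128)
targets = torch.zeros(1024, dtype=torch.long)
static_ds = TensorDataset(inputs, targets)

# The same DataLoader settings as the real run would go here.
loader = DataLoader(static_ds, batch_size=32, num_workers=4)
```

If memory stays flat with this dataset under the same num_workers, the leak is in the real Dataset; if it still spikes, the DataLoader setup or training loop becomes the suspect.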
Hi @ptrblck, thanks! We have not seen memory-leak-like behavior before the spike, and we were focusing on the Dataset and DataLoader because of the variable time the job runs before it fails. I can provide the implementation below, but it may take a while before I see any results from the "TensorDataset with static tensors" experiment.
```python
import os
import random
from typing import Iterable

import lmdb
import numpy as np
from torch.utils.data import Dataset


class LargeSequenceSet(Dataset):
    def __init__(
        self, processed_fn: str, transform=None, indices: Iterable = None
    ):
        super().__init__()
        self.processed_file = os.path.expanduser(processed_fn)
        self.transform = transform
        self.db = None
        self.keys = None
        # np.int was removed in NumPy 1.24; use a concrete dtype, and
        # guard against indices=None, which np.asarray would mangle.
        self.indices = (
            np.asarray(indices, dtype=np.int64) if indices is not None else None
        )
        self._len = None

    def _connect_db(self):
        """Establish a read-only database connection (lazily, per worker)."""
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_file,
            max_readers=4096,
            map_size=107374182400,  # 100 GiB
            readonly=True,
        )
        with self.db.begin() as txn:
            keys = list(txn.cursor().iternext(values=False))
        if self.indices is not None:
            keys = list_index_select(keys, self.indices)  # helper defined elsewhere
        # Keys are kept as a NumPy byte-string array rather than a Python
        # list to avoid copy-on-write memory growth in DataLoader workers.
        self.keys = np.asarray(keys, dtype=np.string_)
        self._len = len(self.keys)

    def _close_db(self):
        self.db.close()
        self.db = None
        self.keys = None

    def __len__(self):
        if self._len is None:
            self._connect_db()
            self._close_db()
        return self._len

    def __getitem__(self, index):
        if self.db is None:
            self._connect_db()
        key = self.keys[index]
        data = self.db.begin().get(key).decode("ascii")
        # Resample a random index until the entry contains only
        # in-vocabulary characters (`vocab` is defined elsewhere).
        while any(a not in vocab for a in data):
            index = random.randint(0, len(self.keys) - 1)
            key = self.keys[index]
            data = self.db.begin().get(key).decode("ascii")
        if self.transform is not None:
            data = self.transform(data)
        return data
```
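As an aside, the rejection loop in `__getitem__` can be isolated and tested on its own; a toy sketch (the store, vocab, and function name here are all hypothetical):

```python
import random

def sample_valid(store, vocab, index, rng=random):
    """Return an entry whose characters are all in vocab, resampling a
    random index whenever the current entry has an unknown character."""
    data = store[index]
    while any(ch not in vocab for ch in data):
        index = rng.randint(0, len(store) - 1)
        data = store[index]
    return data
```

One caveat: if no entry in the store is fully in-vocabulary, this loop never terminates, so a retry cap may be worth adding.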
Hi @ptrblck, I would like to share some recent discoveries we have made on this issue.
We followed gdb - Python script terminated by SIGKILL rather than throwing MemoryError - Stack Overflow to make the process report a Python stack trace; however, the memory spike went away. We then tried several freshly installed machines and saw the same thing: there is no spiking when the oom-reaper is disabled, and spikes happen with the oom-reaper left at its default settings.
We are not yet sure how to interpret this result, but we would like to share it.
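For what it's worth, a stdlib way to get a Python stack dump on demand, without gdb, is `faulthandler` (the log path below is an arbitrary choice). It cannot catch the SIGKILL sent by the oom-reaper, so it only helps if the dump is triggered while memory is still climbing:

```python
import faulthandler
import signal

# Dump all threads' Python stacks to this file whenever the process
# receives SIGUSR1 (e.g. `kill -USR1 <pid>` from a shell).
log = open("py_stacks.log", "w")
faulthandler.register(signal.SIGUSR1, file=log, all_threads=True)
```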
Any ideas how to do this when running training inside Docker containers? We're facing the exact same issue: memory spikes randomly several hours into training.
Hi, I wonder if you have figured out how to solve this problem. I am facing the same issue: memory increases suddenly after a period of stable training (20-30 hours). Did you solve it by modifying sysctl.conf?
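I'm not sure which sysctl.conf change is meant here, but the knob most often discussed alongside OOM behavior is the kernel overcommit policy; inspecting it looks like this (read-only commands, values are system-dependent):

```shell
# 0 = heuristic overcommit (default), 1 = always overcommit, 2 = never
cat /proc/sys/vm/overcommit_memory
cat /proc/sys/vm/overcommit_ratio
# A persistent change would go in /etc/sysctl.conf and be applied
# with `sysctl -p` (requires root).
```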