Hi,
I am building a sort of memory network and I am running into RAM problems (most likely a memory leak). I have been investigating for the past three days but I am still not sure of the cause. Below are some observations and the relevant code. The leak causes the model to train for about 3 to 4 epochs before throwing an OOM.
Because I call a function that caps the memory size once per epoch, the RAM usage of the memory component should stay bounded, yet it keeps increasing throughout training. Some observations:
- The RAM usage is constant if nothing is added to the memory, so the rest of the architecture/training loop does not seem to be the problem.
- If I use the memory component and add to it for only part of the training, the GPU memory usage still increases throughout the whole training (just less steeply while nothing is being added).
- I am detaching / wrapping things in torch.no_grad() as much as possible.
I have also tried to del the memory entirely, recreate it, and put the weights and memory cells back in. That didn't work either.
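For reference, the reset attempt looked roughly like this (a sketch, not the exact code; MemoryModule and its arguments stand in for my actual constructor):

# save the current contents and weights, rebuild the module, restore everything
cells, vals, counters = (self.memory.memory_cells,
                         self.memory.memory_values,
                         self.memory.memory_counters)
state = self.memory.state_dict()
del self.memory
torch.cuda.empty_cache()
self.memory = MemoryModule(...)  # hypothetical: same arguments as the original construction
self.memory.load_state_dict(state)
self.memory.memory_cells = cells
self.memory.memory_values = vals
self.memory.memory_counters = counters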
Now for some code:
I have a memory module with several components; its __init__ contains:
from collections import defaultdict

# persistent buffers holding the memory contents (XXX/YYY are the feature/label dims)
self.memory_cells = torch.zeros(0, XXX, requires_grad=False).cuda()
self.memory_values = torch.zeros(0, YYY, requires_grad=False).long().cuda()
self.memory_hashes = defaultdict(bool)  # tracks which elements have already been stored
self.memory_counters = torch.zeros(0, 1, requires_grad=False).long().cuda()
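One detail about memory_hashes worth spelling out: defaultdict(bool) inserts a key on every lookup of a missing key, so even read-only membership checks make it grow:

from collections import defaultdict

d = defaultdict(bool)
_ = d["unseen"]  # looking up a missing key inserts "unseen" -> False
print(len(d))    # prints 1: the dict grew from the read alone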
I also need to add to the memory; this adds 5 items per batch:
def _add_to_memory(self, elements, values, indexing_values, max_loss=True):
    with torch.no_grad():
        # sort incoming elements by their indexing value (the loss), highest first
        highest_idxs = indexing_values.argsort(dim=0, descending=max_loss)
        elements = elements[highest_idxs]
        values = values[highest_idxs]
        # take the first 5 elements not already present in memory
        to_add = []
        for j, element in enumerate(elements):
            if not self.memory_hashes[element]:  # NB: the dict is keyed on the tensor itself
                to_add.append(j)
                self.memory_hashes[element] = True
            if len(to_add) == 5:
                break
        new_counters = torch.zeros(len(to_add), 1, requires_grad=False).long().cuda()
        new_cells = elements[to_add]
        new_vals = values[to_add]
        # grow the memory buffers by concatenation
        self.memory_cells = torch.cat((self.memory_cells, new_cells), dim=0)
        self.memory_values = torch.cat((self.memory_values.view(-1, YYY), new_vals), dim=0)
        self.memory_counters = torch.cat((self.memory_counters.view(-1, 1), new_counters), dim=0)
        del highest_idxs, elements, values, new_cells, new_vals, new_counters
The function above is called every batch as:
with torch.no_grad():
    # one-hot encode the labels into (batch_size * max_len, num_labels)
    values = torch.zeros(batch_size * max_len, self.num_labels, dtype=torch.long).cuda()
    values.scatter_(1, labels.view(-1, 1),
                    torch.ones(batch_size * max_len, dtype=torch.long).cuda().view(-1, 1))
    self.memory._add_to_memory(sequence_output.detach().view(-1, feat_dim), values.detach(),
                               loss_expanded.detach().view(-1), max_loss=True)
    del values
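As an aside, the scatter_ above is just one-hot encoding the labels; assuming labels takes values in [0, self.num_labels), it should be equivalent to:

import torch.nn.functional as F

# same (batch_size * max_len, num_labels) one-hot tensor as the scatter_ call
values = F.one_hot(labels.view(-1), num_classes=self.num_labels).long().cuda()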
Every once in a while I want to keep only the memory entries that are used the most:
def keep_best_memory(self, max_size):
    if self.memory_counters.shape[0] != 0:
        with torch.no_grad():
            # indices of the (at most) max_size most-used entries
            idxs = self.memory_counters.view(-1).argsort(descending=True)[
                :min(max_size, self.memory_counters.shape[0])]
            new_cells = self.memory_cells[idxs].detach().clone()
            new_vals = self.memory_values[idxs].detach().clone()
            new_counters = torch.zeros(idxs.shape[0], 1, requires_grad=False, dtype=torch.long).cuda()
            # drop the old, larger tensors and keep only the pruned copies
            # (memory_hashes is left untouched here)
            del self.memory_cells, self.memory_values, self.memory_counters
            self.memory_cells = new_cells.clone()
            self.memory_values = new_vals.clone()
            self.memory_counters = new_counters.clone()
            del new_cells, new_vals, new_counters
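I call this once per epoch from the training loop; the call site is essentially the following (train_one_epoch and MEMORY_MAX_SIZE are just illustrative names):

# at the end of every epoch, prune the memory back to a fixed cap
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)             # hypothetical per-epoch training helper
    model.memory.keep_best_memory(MEMORY_MAX_SIZE)   # keep only the most-used entries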
Finally, the forward pass of the memory component. The FF layers are defined elsewhere, but I doubt they are the problem.
def forward(self, x, *args, **kwargs):
    batch, sent_len, feats = x.shape
    x = x.view(-1, feats)
    query = self.query_fc(x)
    keys = self.key_fc(self.memory_cells)
    sims = torch.matmul(query, keys.t())  # dot similarity
    idxs = sims.sort(dim=1, descending=True).indices
    # multidimensional sort (see the smart_sort sketch below), not relevant for the question
    k_highest_sims = smart_sort(sims, idxs)[:, :self.top_k_similarities]
    classifications = self.memory_values[idxs[:, :self.top_k_similarities]]
    self.memory_counters[idxs[:, :self.top_k_similarities]] += 1  # in-place usage-count update
    z = self.memory_gate(k_highest_sims.view(-1, self.top_k_similarities)).view(-1, 1).expand(-1, self.num_classes)
    softmaxed_sims = k_highest_sims.view(-1, self.top_k_similarities, 1).expand(
        -1, self.top_k_similarities, self.num_classes)
    logits_memory = (classifications * softmaxed_sims).sum(dim=1)
    return (logits_memory.view(batch, sent_len, self.num_classes),
            z.view(batch, sent_len, self.num_classes),
            k_highest_sims, classifications)
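(smart_sort is defined elsewhere; it is essentially a per-row reordering of sims by the given indices, roughly a torch.gather along dim 1:)

def smart_sort(x, permutation):
    # reorder each row of x according to the per-row indices in permutation
    return torch.gather(x, 1, permutation)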
And in the main network the memory is used like so, where logits_network is produced by another part of the network. Using only those logits causes no problem at all.
if self.memory.is_ready():
    logits_memory, z, _, _ = self.memory(sequence_output, logits_network, similarity='angular')
    # gate between the network logits and the memory logits
    logits = (1 - z) * logits_network + z * logits_memory
else:
    logits_memory = None
    logits = logits_network
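(is_ready just checks that the memory already holds entries; roughly:)

def is_ready(self):
    # the memory is usable once it contains at least one entry
    return self.memory_cells.shape[0] > 0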
Memory usage: [plot omitted]
I understand the problem description is pretty vague, but I am really out of ideas; any help would be appreciated.