Memory usage increasing despite fixed-size memory network

Hi,
I am building a kind of memory network and am running into RAM problems (most likely a memory leak). Unfortunately, despite three days of investigating, I'm still not sure of the cause. Below are some things I have noticed, followed by the relevant code. Because of this, the model trains for only about 3 to 4 epochs before throwing an out-of-memory (OOM) error.
I call the function that caps the memory size once every epoch, so the RAM used by the memory component should be fixed, or at least bounded. Instead, it keeps increasing throughout training.

  • RAM usage stays constant if nothing is added to the memory, so the rest of the architecture and training loop seem fine.
  • If I use the memory component but only add to it for part of training, GPU usage still increases throughout the whole run (less steeply once nothing is being added).
  • I detach tensors and wrap operations in torch.no_grad() wherever possible.
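
A quick way to check those assumptions, sketched under stated assumptions (PyTorch 2.0 or later for `untyped_storage()`; the helper name `storage_bytes` and all sizes are hypothetical stand-ins for max_size, XXX and YYY): `detach()` and `no_grad()` stop autograd from recording history, but neither copies nor frees data, and a view keeps its whole parent storage alive.

```python
import torch


def storage_bytes(*tensors: torch.Tensor) -> int:
    """Hypothetical helper: total bytes of the storages backing `tensors`.

    Counts each tensor's full underlying storage rather than
    numel() * element_size(), because a view pins its parent's storage.
    """
    return sum(t.untyped_storage().nbytes() for t in tensors)


# Hypothetical sizes standing in for max_size, XXX and YYY.
max_size, cell_dim, num_labels = 1000, 768, 9
cells = torch.zeros(max_size, cell_dim)                       # float32, 4 bytes each
values = torch.zeros(max_size, num_labels, dtype=torch.long)  # int64, 8 bytes each
counters = torch.zeros(max_size, 1, dtype=torch.long)
print(storage_bytes(cells, values, counters))  # 3152000

# A single row sliced out of `cells` still references the full storage.
row = cells[0]
print(storage_bytes(row))  # 3072000

# no_grad(): results carry no autograd history.
w = torch.randn(4, 3, requires_grad=True)
with torch.no_grad():
    y = w * 2
print(y.requires_grad, y.grad_fn)  # False None

# detach(): cuts the graph but shares memory with the original tensor.
d = w.detach()
print(d.data_ptr() == w.data_ptr())  # True
```

On the GPU, logging `torch.cuda.memory_allocated()` once per epoch and comparing it with such an expected bound is one way to see which part keeps growing.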

I have also tried deleting the memory entirely, recreating it, and copying the weights and memory cells back in. That didn't help.

Now for some code:

I have a memory module with several components, initialized as follows:

        self.memory_cells = torch.zeros(0, XXX, requires_grad=False).cuda()
        self.memory_values = torch.zeros(0, YYY, requires_grad=False).long().cuda()
        self.memory_hashes = defaultdict(bool)
        self.memory_counters = torch.zeros(0, 1, requires_grad=False).long().cuda()
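
For reference, a minimal sketch of holding the same growable tensors as registered buffers instead of plain attributes, assuming a standard nn.Module (the class name `MemoryBuffers` and the sizes are hypothetical). Buffers get no gradients, move with `.to()`/`.cuda()`, appear in `state_dict()`, and assigning a new tensor to a registered name keeps it registered as a buffer.

```python
import torch
from torch import nn


class MemoryBuffers(nn.Module):
    """Hypothetical container for an initially empty, growable memory."""

    def __init__(self, cell_dim: int, num_labels: int):
        super().__init__()
        # Buffers: no gradients, follow .to()/.cuda(), included in state_dict().
        self.register_buffer("memory_cells", torch.zeros(0, cell_dim))
        self.register_buffer("memory_values", torch.zeros(0, num_labels, dtype=torch.long))
        self.register_buffer("memory_counters", torch.zeros(0, 1, dtype=torch.long))


mem = MemoryBuffers(cell_dim=768, num_labels=9)
print(sorted(mem.state_dict()))  # ['memory_cells', 'memory_counters', 'memory_values']

# Assigning a tensor to a registered buffer name updates the buffer.
mem.memory_cells = torch.cat((mem.memory_cells, torch.ones(2, 768)), dim=0)
print(mem.memory_cells.shape)  # torch.Size([2, 768])
print("memory_cells" in dict(mem.named_buffers()))  # True
```

One caveat with this design: `del self.memory_cells` removes the buffer registration, so a later assignment to that name creates a plain attribute instead.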

I also need to add to the memory; this adds up to 5 items per batch:

    def _add_to_memory(self, elements, values, indexing_values):
        with torch.no_grad():
            highest_idxs = indexing_values.argsort(dim=0, descending=True)
            elements = elements[highest_idxs]
            values = values[highest_idxs]

            to_add = []

            for j, element in enumerate(elements):
                if not self.memory_hashes[element]:
                    to_add.append(j)
                    self.memory_hashes[element] = True
                    if len(to_add) == 5:
                        break

            new_counters = torch.zeros(len(to_add), 1, requires_grad=False).long().cuda()
            new_cells = elements[to_add]
            new_vals = values[to_add]

            self.memory_cells = torch.cat((self.memory_cells, new_cells), dim=0)
            self.memory_values = torch.cat((self.memory_values.view(-1, YYY), new_vals), dim=0)
            self.memory_counters = torch.cat((self.memory_counters.view(-1, 1), new_counters), dim=0)

            del highest_idxs, elements, values, new_cells, new_vals, new_counters

The function above is called on every batch like this:

                with torch.no_grad():
                    values = torch.zeros(batch_size * max_len, self.num_labels, dtype=torch.long).cuda()
                    values.scatter_(1, labels.view(-1, 1),
                                    torch.ones(batch_size * max_len, dtype=torch.long).cuda().view(-1, 1))
                    self.memory._add_to_memory(sequence_output.detach().view(-1, feat_dim), values.detach(),
                                               loss_expanded.detach().view(-1))

                    del values
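
Two details of this add path may be worth checking. First, `torch.Tensor` hashes by object identity, so `self.memory_hashes[element]` does not match a different tensor with the same values; and because `memory_hashes` is a `defaultdict`, the lookup in the `if` inserts every checked row as a new key. Each key is a view of the reordered `elements` tensor, so it keeps that tensor's storage referenced for as long as the dictionary lives. Second, `zeros` followed by `scatter_` builds a one-hot encoding, which `torch.nn.functional.one_hot` also produces. The sketch below is hypothetical (the function name `add_unique_rows`, the byte-string keys and `max_new=5` are assumptions): it keys a plain set on a CPU byte copy of each row, so no tensor views are retained, and deduplicates only on exact float equality.

```python
import torch
import torch.nn.functional as F


def add_unique_rows(memory_cells, memory_values, seen, elements, values, scores, max_new=5):
    """Hypothetical sketch: append up to `max_new` unseen rows, highest score first.

    `seen` is a plain set of bytes keys, so it stores CPU copies of row contents
    instead of references to (possibly GPU) tensor views.
    """
    with torch.no_grad():
        order = scores.argsort(dim=0, descending=True)
        elements, values = elements[order], values[order]
        rows = elements.cpu().numpy()  # one device-to-host copy per call
        to_add = []
        for j in range(rows.shape[0]):
            key = rows[j].tobytes()  # content-based key (exact float equality)
            if key not in seen:
                seen.add(key)
                to_add.append(j)
                if len(to_add) == max_new:
                    break
        idx = torch.tensor(to_add, dtype=torch.long, device=elements.device)
        memory_cells = torch.cat((memory_cells, elements[idx]), dim=0)
        memory_values = torch.cat((memory_values, values[idx]), dim=0)
    return memory_cells, memory_values


# One-hot labels: zeros + scatter_ (as in the call above) matches F.one_hot.
labels = torch.tensor([2, 0, 1])
num_labels = 3
via_scatter = torch.zeros(3, num_labels, dtype=torch.long)
via_scatter.scatter_(1, labels.view(-1, 1), torch.ones(3, 1, dtype=torch.long))
via_one_hot = F.one_hot(labels, num_classes=num_labels)
print(torch.equal(via_scatter, via_one_hot))  # True

# Rows 0 and 1 have identical contents, so only one of them is stored.
seen = set()
cells = torch.zeros(0, 4)
vals = torch.zeros(0, num_labels, dtype=torch.long)
batch = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
cells, vals = add_unique_rows(cells, vals, seen, batch, via_one_hot, torch.tensor([0.9, 0.8, 0.1]))
print(cells.shape)  # torch.Size([2, 4])
```

Copying the whole batch to the CPU once per call avoids a separate device synchronization for every row.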

Every once in a while, I want to keep only the memory entries that are used the most:

    def keep_best_memory(self, max_size):
        if self.memory_counters.shape[0] != 0:
            with torch.no_grad():
                idxs = self.memory_counters.view(-1).argsort(descending=True)[
                       :min(max_size, self.memory_counters.shape[0])]

                new_cells = self.memory_cells[idxs].detach().clone()
                new_vals = self.memory_values[idxs].detach().clone()
                new_counters = torch.zeros(idxs.shape[0], 1, requires_grad=False, dtype=torch.long).cuda()

                del self.memory_cells, self.memory_values, self.memory_counters

                self.memory_cells = new_cells.clone()
                self.memory_values = new_vals.clone()
                self.memory_counters = new_counters.clone()

                del new_cells, new_vals, new_counters
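
The pruning step amounts to a top-k selection on the counters. A minimal sketch under the same shapes (the function name `keep_top_used` is hypothetical) using `torch.topk`; indexing with a tensor of indices already returns new tensors rather than views, so no extra `.clone()` is needed for the copies to be independent. Note that `keep_best_memory` only trims the three tensors; `memory_hashes` is never pruned.

```python
import torch


def keep_top_used(cells, values, counters, max_size):
    """Hypothetical sketch: keep the `max_size` most-used slots, reset their counters."""
    with torch.no_grad():
        k = min(max_size, counters.shape[0])
        idx = counters.view(-1).topk(k).indices  # sorted by usage, descending
        # Indexing with a tensor of indices copies, so these are new tensors, not views.
        new_counters = torch.zeros(k, 1, dtype=counters.dtype, device=counters.device)
        return cells[idx], values[idx], new_counters


cells = torch.arange(12.0).view(6, 2)
values = torch.arange(6).view(6, 1)
counters = torch.tensor([[5], [0], [9], [1], [7], [3]])
c, v, n = keep_top_used(cells, values, counters, max_size=3)
print(v.view(-1).tolist())  # [2, 4, 0]
```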

Finally, the forward pass of the memory component. The feed-forward layers are defined elsewhere, but I doubt they are the problem.

    def forward(self, x, *args, **kwargs):
        batch, sent_len, feats = x.shape
        x = x.view(-1, feats)

        query = self.query_fc(x)
        keys = self.key_fc(self.memory_cells)

        sims = torch.matmul(query, keys.t())  # dot similarity

        idxs = sims.sort(dim=1, descending=True).indices
        k_highest_sims = smart_sort(sims, idxs)[:, :self.top_k_similarities] # multidimensional sorting, not relevant for the question
        classifications = self.memory_values[idxs[:, :self.top_k_similarities]]

        self.memory_counters[idxs[:, :self.top_k_similarities]] += 1

        z = self.memory_gate(k_highest_sims.view(-1, self.top_k_similarities)).view(-1, 1).expand(-1, self.num_classes)

        softmaxed_sims = k_highest_sims.view(-1, self.top_k_similarities, 1).expand(-1, self.top_k_similarities,
                                                                                    self.num_classes)
        logits_memory = (classifications * softmaxed_sims).sum(dim=1)

        return logits_memory.view(batch, sent_len, self.num_classes), z.view(batch, sent_len,
                                                                             self.num_classes), k_highest_sims, classifications
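
The full sort followed by slicing to `top_k_similarities` can be expressed with a single `torch.topk` call, which returns the k largest similarities per row and their indices (the order among tied values may differ from `sort`). Separately, `self.memory_counters[idx] += 1` with an index tensor does not accumulate repeated indices: if several queries pick the same slot, that slot is incremented only once. A hypothetical sketch (the helper name `retrieve_top_k` and the toy shapes are assumptions) that counts every hit with `index_add_`:

```python
import torch


def retrieve_top_k(sims, memory_values, counters, k):
    """Hypothetical sketch of the retrieval step.

    sims: (N, M) query-to-memory similarities
    memory_values: (M, C) one-hot labels of the memory slots
    counters: (M, 1) usage counts, updated in place
    """
    k_highest_sims, top_idx = sims.topk(k, dim=1)  # both (N, k), sorted descending
    classifications = memory_values[top_idx]  # (N, k, C)
    with torch.no_grad():
        flat = top_idx.reshape(-1)
        # index_add_ accumulates repeated indices; counters[top_idx] += 1 would not.
        counters.view(-1).index_add_(0, flat, torch.ones_like(flat))
    return k_highest_sims, classifications


sims = torch.tensor([[0.1, 0.9, 0.5],
                     [0.8, 0.7, 0.2]])
memory_values = torch.eye(3, dtype=torch.long)
counters = torch.zeros(3, 1, dtype=torch.long)
top_sims, cls = retrieve_top_k(sims, memory_values, counters, k=2)
print(counters.view(-1).tolist())  # [1, 2, 1]

# The advanced-indexing version counts slot 1 only once.
naive = torch.zeros(3, 1, dtype=torch.long)
naive[sims.topk(2, dim=1).indices] += 1
print(naive.view(-1).tolist())  # [1, 1, 1]
```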

In the main network, the memory is used as shown below, where logits_network comes from another part of the network. Using only those logits causes no problems at all.

        if self.memory.is_ready():
            logits_memory, z, _, _ = self.memory(sequence_output, logits_network, similarity='angular')
            logits = (1 - z) * logits_network + z * logits_memory
        else:
            logits_memory = None
            logits = logits_network
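
The gated combination is a per-element linear interpolation between the two sets of logits, which PyTorch also provides as `torch.lerp(start, end, weight)`. A small check with made-up values (the tensors below are hypothetical; `z` is broadcast per class as in `forward`):

```python
import torch

logits_network = torch.tensor([[1.0, -2.0, 0.5]])
logits_memory = torch.tensor([[3.0, 0.0, -1.0]])
z = torch.full_like(logits_network, 0.25)  # hypothetical gate value for every class

mixed = (1 - z) * logits_network + z * logits_memory
print(torch.allclose(mixed, torch.lerp(logits_network, logits_memory, z)))  # True
```

With z equal to 0 only the network logits contribute, and with z equal to 1 only the memory logits do.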

[Figure: memory usage plot]

I understand the problem is rather vague, but I'm really out of ideas. Any help would be appreciated.