GPU memory accumulation in parallel_apply

My aim is to build a linear layer with a very large output dimension. To achieve this I store the weights of the linear layer in an embedding layer. Furthermore, I need to forward and backward only through some connections of the fully connected layer (hence the "shortlist"). Since the output size is large, I split the embedding layer across 2 GPUs.
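
In other words, for each sample b and each shortlisted position k the layer only computes (using the names from the code below)

out[b, k] = dot(embed[b], weight[shortlist[b, k]]) + bias[shortlist[b, k]]

instead of the full embed @ weight.T over all output labels.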

Relevant parts of the code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class SparseLinear(nn.Module):
    def __init__(self, num_labels, hidden_size, device_embeddings, bias=True):
        super(SparseLinear, self).__init__()
        
        self.device_embeddings = device_embeddings
        
        self.input_size = hidden_size
        self.output_size = num_labels
        self.weight = Parameter(torch.Tensor(self.output_size, self.input_size))
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size, 1))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.sparse = True  # Required for optimizer

    def forward(self, embed, shortlist):
        short_weights = F.embedding(shortlist,
                                    self.weight,
                                    sparse=self.sparse)
        
        out = torch.matmul(embed.unsqueeze(1), short_weights.permute(0, 2, 1))
        
        short_bias = F.embedding(shortlist,
                                 self.bias,
                                 sparse=self.sparse)
        out = out + short_bias.permute(0, 2, 1)
        del short_weights
        return out.squeeze()
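
For concreteness, the lookup that SparseLinear.forward performs boils down to this toy, CPU-only sketch (all sizes here are made up for illustration, not my real ones):

import torch
import torch.nn.functional as F

B, H, K, num_labels = 4, 300, 8, 10000         # toy sizes
weight = torch.randn(num_labels, H, requires_grad=True)
bias = torch.randn(num_labels, 1, requires_grad=True)

embed = torch.randn(B, H)                      # transformed document embeddings
shortlist = torch.randint(num_labels, (B, K))  # K candidate labels per sample

short_w = F.embedding(shortlist, weight, sparse=True)   # [B, K, H]
short_b = F.embedding(shortlist, bias, sparse=True)     # [B, K, 1]
out = torch.matmul(embed.unsqueeze(1), short_w.permute(0, 2, 1))  # [B, 1, K]
out = (out + short_b.permute(0, 2, 1)).squeeze()                  # [B, K] scores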

class DividedLinear(DeepXMLBase):
    def __init__(self, <params>):
        # Say I have an output size of 1000000, and I divide it into 2 parts
        self.label_partition_lengths = [(500000, "cuda:0"), (500000, "cuda:1")]
        self.classifier = [SparseLinear(num_labels, 300, torch.device(device_name))
                           for num_labels, device_name in self.label_partition_lengths]
        <init other params>
    
    def encode(self, batch_data):
        # self.transform is some network that transforms the embeddings
        return self.transform(batch_data["doc_embeddings"].to(self.device_embeddings))
    
    
    def forward_with_error_calc(self, batch_data, criterion):
        print("before", torch.cuda.memory_allocated(1) / (1024 * 1024 * 1024), 
                      torch.cuda.memory_allocated(2) / (1024 * 1024 * 1024))
        encoded = self.encode(batch_data)
        device_embeddings = [torch.device(device_name) for _, device_name in self.label_partition_lengths]
        
        shortlists = [x.to(device_embeddings[i]) for i, x in enumerate(batch_data["shortlist"])]
        encoded_replicate = [encoded.to(device_embeddings[i]) for i in range(len(device_embeddings))]
        
        outputs = nn.parallel.parallel_apply(self.classifier, list(zip(encoded_replicate, shortlists)))
        
        targets = [batch_data["shortlist_weights"][i].to(device_embeddings[i]) for i in range(len(device_embeddings))]
        errors = nn.parallel.parallel_apply(
            nn.parallel.replicate(criterion, device_embeddings),
            list(zip(outputs, targets)))
        errors_gather = nn.parallel.gather(errors, target_device=device_embeddings[0])
        
        total_error = errors_gather.sum()
        
        print("after", torch.cuda.memory_allocated(1) / (1024 * 1024 * 1024), 
                      torch.cuda.memory_allocated(2) / (1024 * 1024 * 1024))
        
        del outputs, targets, shortlists, encoded_replicate
        torch.cuda.empty_cache()
        print("after del", torch.cuda.memory_allocated(1) / (1024 * 1024 * 1024), 
                      torch.cuda.memory_allocated(2) / (1024 * 1024 * 1024))

        return total_error
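
Stripped of the model-specific parts, the cross-device pattern in forward_with_error_calc is essentially the following minimal sketch with toy modules and made-up sizes (it assumes two visible GPUs and is only meant to illustrate the parallel_apply / replicate / gather flow, not my actual code):

import torch
import torch.nn as nn

devices = [torch.device("cuda:0"), torch.device("cuda:1")]

# one small classifier per device (toy sizes)
classifiers = [nn.Linear(300, 10).to(d) for d in devices]
criterion = nn.BCEWithLogitsLoss(reduction='none')

x = torch.randn(32, 300)
inputs = [x.to(d) for d in devices]                     # copy of the batch on each device
targets = [torch.rand(32, 10).to(d) for d in devices]

# run both classifiers concurrently, one per GPU
outputs = nn.parallel.parallel_apply(classifiers, inputs)

# replicate the loss module and compute the per-device losses in parallel
losses = nn.parallel.parallel_apply(
    nn.parallel.replicate(criterion, devices),
    list(zip(outputs, targets)))

# bring the losses to one device and reduce
total = nn.parallel.gather(losses, target_device=devices[0]).sum()
total.backward()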

But the GPUs run out of memory after some batches. In particular, I observe behavior like this:

before 5.049468994140625 5.049465179443359
after 5.1367316246032715 5.1367268562316895
after del 5.1367316246032715 5.1367268562316895
before 5.136678695678711 5.1366729736328125
after 5.223941326141357 5.223934650421143
after del 5.223941326141357 5.223934650421143

Each batch adds roughly 0.087 GB on each GPU, and neither the explicit del calls nor torch.cuda.empty_cache() release it. So there seems to be some leakage in the forward_with_error_calc function, but I can't figure out what it is. Can someone please help me figure this out? TIA.