Memory Leakage when updating specific elements in Tensor

I have made a partially connected layer, in which input layer and output layer are not completely connected according to adjacency matrix.

input layer [in_features]
weight matrix [out_features,in_features] : only specific element is non-zero and learnable.
bias [out_features]
output layer [out_features]

However, It has a memory leakage problem.

How can I index and update specific elements in tensor without making a new tensor.

Below is my code.

from torch.nn.parameter import Parameter
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np
import torch

def kaiming_uniform_fan_(tensor, fan=None, a=0, nonlinearity='leaky_relu'):
    gain = init.calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

class Linear_Partial(nn.Module):
    __constants__ = ['bias', 'adjacency']

    def __init__(self, adjacency,bias=True):
        super(Linear_Partial, self).__init__()
        # adjacency matrix: non-zero indice are only learnable    
        assert ((adjacency!=1) & (adjacency!=0)).sum()==0, 'value not in [0,1]'
        self.adjacency = Parameter(torch.Tensor(adjacency))
        # elements in weights which are learnable.

        #self.weight[adjacency==1] = self.weight_learnable
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features))
            self.register_parameter('bias', None)
    # since weight is not registered as parameter, we have to update weight value manually    
    def update_weight_manually(self):
        self.weight[self.adjacency==1] = self.weight_learnable
    def reset_parameters(self):
        kaiming_uniform_fan_(self.weight_learnable, fan=self.in_features, a=math.sqrt(5))
        #self.weight[self.adjacency==0] = 0
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None



device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

def init_weights(m,device):
    if type(m) == Linear_Partial:
        print("Sent to {}: {}".format(device,m))

if True:
    model.apply(lambda m: init_weights(m,device))


Error message

RuntimeError                              Traceback (most recent call last)
<ipython-input-3-4099065c43f2> in <module>
----> 1 output_sample=model(input_sample)

~/tools/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/ in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

<ipython-input-1-fb1702069a84> in forward(self, input)
     58     def forward(self, input):
---> 59         self.update_weight_manually()
     60         return F.linear(input, self.weight, self.bias)

<ipython-input-1-fb1702069a84> in update_weight_manually(self)
     45     def update_weight_manually(self):
---> 46         self.weight[self.adjacency==1] = self.weight_learnable
     48     def reset_parameters(self):

RuntimeError: CUDA out of memory. Tried to allocate 2.98 GiB (GPU 1; 7.77 GiB total capacity; 2.24 GiB already allocated; 2.02 GiB free; 2.80 GiB cached)


I guess you the problem is your update_weight_manually function: It is always updated in a differentiable manner, meaning that it will keep track of all it’s history.
You want to make sure that at each forward, you work with a self.weight that has no history.
You want to make sure this does not happen with:

    def update_weight_manually(self):
        self.weight = self.weight.detach()
        self.weight[self.adjacency==1] = self.weight_learnable

Also your current implementation might keep non-zero values from one forward to the next if the adjacency changes. If the adjacency can change, you might want to replace the detach by self.weight=torch.zeros([self.out_features,self.in_features]).