Increasing memory usage for forward pass

I am trying to train a Tree-LSTM model (reference code here - https://github.com/ttpro1995/TreeLSTMSentiment/blob/master/model.py) and noticed that my memory usage steadily increases while training. To narrow down where the problem occurs, I ran repeated forward passes on their own and found that memory usage grows even then (see the check sketched after the code below). I have been at this for 2 days and have been unable to identify why the memory keeps increasing.

import torch
import torch.nn as nn
from torch.autograd import Variable

class treeEncoder(nn.Module):
    def __init__(self, cuda, in_dim, mem_dim, wordVects, labels, labelMap, criterion, device):
        super(treeEncoder, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.device = device
        self.labels = labels
        self.labelMap = labelMap
        self.criterion = criterion

        self.ix = nn.Linear(self.in_dim,self.mem_dim)
        self.ih = nn.Linear(self.mem_dim,self.mem_dim)

        self.fx = nn.Linear(self.in_dim,self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

        self.ux = nn.Linear(self.in_dim,self.mem_dim)
        self.uh = nn.Linear(self.mem_dim,self.mem_dim)

        self.ox = nn.Linear(self.in_dim,self.mem_dim)
        self.oh = nn.Linear(self.mem_dim,self.mem_dim)
        
        self.wordVects = wordVects
    
    def forward(self,node):
        loss = Variable(torch.zeros(1))
        
        if self.cudaFlag:
            loss = loss.to(self.device)
        
        for i in range(node.num_children):
            _, child_loss = self.forward(node.childrenList[i])
            loss = loss + child_loss
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.wordVects[node.uid].to(self.device),child_c,child_h)
        
        # NOTE: `output` was not defined in the original snippet; it is assumed here to be the
        # node's hidden state (i.e. mem_dim equals the number of classes). In the referenced
        # repo a separate classifier is applied to node.state[1] instead.
        output = node.state[1]
        label = Variable(torch.tensor(self.labelMap[node.label]))

        loss = loss + self.criterion(output.reshape(-1, 4), label.reshape(-1).to(self.device))
        
        return node.state, loss
        
    def nodeForward(self,x,child_c,child_h):
        # h^~_j = sum of child hidden states
        child_h_sum = torch.sum(child_h,0)

        i = torch.sigmoid(self.ix(x) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(x) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(x) + self.uh(child_h_sum))

        # one forget gate per child: (num_children, mem_dim)
        fx = self.fx(x)
        f = torch.sigmoid(self.fh(child_h) + fx)

        # child-sum TreeLSTM cell update: forget gates weight the children's cell states
        c = i * u + torch.sum(f * child_c, 0)
        h = o * torch.tanh(c)
        
        return c,h
    
    def getChildStates(self,node):
        if node.num_children==0:
            child_c = Variable(torch.zeros(1,self.mem_dim))
            child_h = Variable(torch.zeros(1,self.mem_dim))
            if self.cudaFlag:
                child_c, child_h = child_c.to(self.device), child_h.to(self.device)
        
        else:
            child_c = Variable(torch.Tensor(node.num_children,self.mem_dim))
            child_h = Variable(torch.Tensor(node.num_children,self.mem_dim))
            if self.cudaFlag:
                child_c, child_h = child_c.to(self.device), child_h.to(self.device)
            
            for idx in range(node.num_children):
                child_c[idx] = node.childrenList[idx].state[0]
                child_h[idx] = node.childrenList[idx].state[1]
        return child_c, child_h
model = treeEncoder('params needed')

preds = []
labels = []

for valSet in x_test:
    finalTree = valSet[-1]
    preds.append(model(finalTree.root)[1])
    print(preds)
    predTensor = torch.stack(preds)
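
This is roughly how I watched the memory grow while running a loop like the one above (a sketch; model and x_test are the objects from the snippets above, and torch.cuda.memory_allocated() is the standard counter for a CUDA device):

import torch

# sketch: run only the forward pass repeatedly and watch how much GPU memory live tensors hold
losses = []
for step, valSet in enumerate(x_test):
    finalTree = valSet[-1]
    losses.append(model(finalTree.root)[1])
    print(step, torch.cuda.memory_allocated() / 1024**2, "MB allocated")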

Are there any variables that I am not clearing, or anything else I am missing? I have tried deleting variables with del, calling torch.cuda.empty_cache(), and all kinds of other things. Any help would be appreciated.

In your loop over x_test it seems you are appending the losses to preds.
Since you are neither detaching the losses nor wrapping the code in torch.no_grad(), each loss tensor will hold on to its computation graph, which increases the memory usage with every iteration.
If you don’t need to call backward on these losses, just use this code:

with torch.no_grad():
    for valSet in x_test:
        finalTree = valSet[-1]
        preds.append(model(finalTree.root)[1])
    predTensor = torch.stack(preds)
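
Regarding del and torch.cuda.empty_cache(): empty_cache() only returns unused cached blocks to the driver; it cannot free tensors that are still referenced (e.g. a stored loss and its graph). You can see the difference between the two allocator counters with something like this (a sketch, assuming a CUDA device):

import torch

def report(tag):
    # allocated = memory occupied by live tensors, reserved = memory held by the caching allocator
    print(tag,
          "allocated:", torch.cuda.memory_allocated() / 1024**2, "MB |",
          "reserved:", torch.cuda.memory_reserved() / 1024**2, "MB")

report("before empty_cache")
torch.cuda.empty_cache()        # shrinks the reserved pool at best
report("after empty_cache")     # allocated stays the same while tensors are still referenced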

@ptrblck This is exactly what I was thinking. In reality I also need to do the backward step as well. preds is reinitialized for every validation step, so I don’t believe it will cause too big a problem.

Now the main question I have is this: as far as I understand, the loss and its computation graph are particular to that one forward pass, right? How do I reset this and free the memory they consume after each forward step?

You could store each loss using preds.append(loss.detach()) and see if the memory growth disappears.

The backward call will clean up the computation graph and its intermediate tensors, as long as you don’t run it with retain_graph=True.
As long as some tensors hold a reference to the computation graph, it is kept alive and uses memory.
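
For the training case, a loop along these lines (a sketch; the Adam optimizer, its learning rate, and x_train are assumptions, not taken from the code above) calls backward for every tree, which releases the graph, and only keeps detached losses around:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # assumed optimizer and learning rate
train_losses = []

for trainSet in x_train:                 # x_train assumed to mirror x_test
    finalTree = trainSet[-1]
    optimizer.zero_grad()
    _, loss = model(finalTree.root)      # forward builds the graph for this tree
    loss.backward()                      # frees the graph and its intermediate tensors
    optimizer.step()
    train_losses.append(loss.detach())   # or loss.item(); keeps no reference to the graph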