I am trying to train a tree lstm model (ref code here - https://github.com/ttpro1995/TreeLSTMSentiment/blob/master/model.py) and noticed that my memory usage is steadily increasing while training. To identify where the problem is occurring, I tried to use some repeated forward pass calls to see if that memory usage is increasing and it does. I have been trying for 2 days and have been unable to identify why the memory usage keeps increasing.

```
class treeEncoder(nn.Module):
def __init__(self, cuda,in_dim, mem_dim,wordVects,labels,labelMap,criterion,device):
super(treeEncoder, self).__init__()
self.cudaFlag = cuda
self.in_dim = in_dim
self.mem_dim = mem_dim
self.device = device
self.labels = labels
self.labelMap = labelMap
self.criterion = criterion
self.ix = nn.Linear(self.in_dim,self.mem_dim)
self.ih = nn.Linear(self.mem_dim,self.mem_dim)
self.fx = nn.Linear(self.in_dim,self.mem_dim)
self.fh = nn.Linear(self.mem_dim, self.mem_dim)
self.ux = nn.Linear(self.in_dim,self.mem_dim)
self.uh = nn.Linear(self.mem_dim,self.mem_dim)
self.ox = nn.Linear(self.in_dim,self.mem_dim)
self.oh = nn.Linear(self.mem_dim,self.mem_dim)
self.wordVects = wordVects
def forward(self,node):
loss = Variable(torch.zeros(1))
if self.cudaFlag:
loss = loss.to(self.device)
for i in range(node.num_children):
_, child_loss = self.forward(node.childrenList[i])
loss = loss + child_loss
child_c, child_h = self.getChildStates(node)
node.state = self.nodeForward(self.wordVects[node.uid].to(self.device),child_c,child_h)
label = Variable(torch.tensor(self.labelMap[node.label]))
loss = loss + self.criterion(output.reshape(-1,4), label.reshape(-1).to(self.device))
return node.state, loss
def nodeForward(self,x,child_c,child_h):
# h^~_j = sum of child hidden states
child_h_sum = torch.sum(child_h,0)
i = torch.sigmoid(self.ix(x) + self.ih(child_h_sum))
o = torch.sigmoid(self.ox(x)+self.oh(child_h_sum))
u = torch.tanh(self.ux(x)+self.uh(child_h_sum))
fx = self.fx(x)
f = torch.cat([self.fh(child_hi)+fx for child_hi in child_h], 0)
fc = torch.sigmoid(f)
c = i*u + torch.sum(fc,0)
h = o*torch.tanh(c)
return c,h
def getChildStates(self,node):
if node.num_children==0:
child_c = Variable(torch.zeros(1,self.mem_dim))
child_h = Variable(torch.zeros(1,self.mem_dim))
if self.cudaFlag:
child_c, child_h = child_c.to(self.device), child_h.to(self.device)
else:
child_c = Variable(torch.Tensor(node.num_children,self.mem_dim))
child_h = Variable(torch.Tensor(node.num_children,self.mem_dim))
if self.cudaFlag:
child_c, child_h = child_c.to(self.device), child_h.to(self.device)
for idx in range(node.num_children):
child_c[idx] = node.childrenList[idx].state[0]
child_h[idx] = node.childrenList[idx].state[1]
return child_c, child_h
model = treeEncoder('params needed')
preds = []
labels = []
for valSet in x_test:
finalTree = valSet[-1]
preds.append(model(finalTree.root)[1])
print(preds)
predTensor = torch.stack(preds)
```

Are there any variables that I am not clearing or anything. I have tried deleting variables with del and torch.cuda.empty_cache() and all kinds of things. Any help would be appreciated.