Error in training LSTM network - gradient propagation through hidden variables!

I am trying to build an LSTM network from scratch using LSTMCell. Four LSTM layers are stacked, and a final fully connected layer gives a single output for each time step. Here is the model code:

class McG_LSTM(nn.Module):
    def __init__(self,size=25,BSIZE=32,LEN=512):
        super(McG_LSTM, self).__init__()
        
        self.size = size
        self.lstm_1 = nn.LSTMCell(1,self.size,True)
        self.lstm_2 = nn.LSTMCell(self.size,self.size,True)
        self.lstm_3 = nn.LSTMCell(self.size,self.size,True)
        self.lstm_4 = nn.LSTMCell(self.size,self.size,True)
        self.fc1 = nn.Linear(self.size, 1)

    def forward(self, x, hidden):
        self.ht1, self.ct1, self.ht2, self.ct2, self.ht3, self.ct3, self.ht4, self.ct4 = hidden 
        self.y.detach()
        self.y.zero_()
        
        for i in range(x.size(0)):
            self.ht1, self.ct1 = self.lstm_1(x[i],(self.ht1,self.ct1))
            self.ht2, self.ct2 = self.lstm_2(self.ht1,(self.ht2,self.ct2))
            self.ht3, self.ct3 = self.lstm_3(self.ht2,(self.ht3,self.ct3))
            self.ht4, self.ct4 = self.lstm_4(self.ht3,(self.ht4,self.ct4))
            self.y[i] = self.fc1(self.ht4)

        return self.y, (self.ht1, self.ct1, self.ht2, self.ct2, self.ht3, self.ct3, self.ht4, self.ct4)
    
    def init_hidden(self,BSIZE,LEN):
        self.register_buffer('ht1', torch.zeros(BSIZE,self.size))
        self.register_buffer('ct1', torch.zeros(BSIZE,self.size))
        self.register_buffer('ht2', torch.zeros(BSIZE,self.size))
        self.register_buffer('ct2', torch.zeros(BSIZE,self.size))
        self.register_buffer('ht3', torch.zeros(BSIZE,self.size))
        self.register_buffer('ct3', torch.zeros(BSIZE,self.size))
        self.register_buffer('ht4', torch.zeros(BSIZE,self.size))
        self.register_buffer('ct4', torch.zeros(BSIZE,self.size))
        self.register_buffer('y', torch.zeros(LEN,BSIZE,1))

        return (self.ht1, self.ct1, self.ht2, self.ct2, self.ht3, self.ct3, self.ht4, self.ct4)

In the main code, I use the following relevant code snippet to train the network:

def repackage_hidden(h):
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

for epoch in range(1,NUMEPOCHS+1):
    model.train()
    count = 0
    for data_input, data_output in train_loader:    
        optimizer.zero_grad()
        data_input = data_input.transpose(0,1).transpose(0,2).to(device, dtype=torch.float)  
        data_output = data_output.transpose(0,1).transpose(0,2).to(device, dtype=torch.float)
        
        hidden = repackage_hidden(hidden)
        output, hidden = model(data_input,hidden) 
        loss = criterion(output, data_output)
        loss.backward()
        optimizer.step()
        count += 1

However, I get the following error:

Traceback (most recent call last):
  File "McG-Experiment.py", line 76, in <module>
    loss.backward()
  File "/home/sc/anaconda3/lib/python3.8/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/sc/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 125, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Note that the training loop above completes the first iteration (count = 0) but fails on the second (count = 1). I cannot understand what the problem is. I made sure to detach the buffer variables!

PS: Any suggestions to optimize the code are also welcome.

I haven’t tested your code, but note that self.y.detach() is not an in-place operation, so you might need to either reassign the detached tensor via self.y = self.y.detach() or use the in-place method self.y.detach_().
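To make the difference concrete, here is a minimal sketch on a standalone tensor rather than on the model's buffer. The names w, y, and z are illustrative only and do not come from the original code.

```python
import torch

w = torch.ones(3, requires_grad=True)

y = w * 2
y.detach()                                # returns a new tensor, which is discarded here
still_attached = y.grad_fn is not None    # y itself is unchanged and still in the graph

y = y.detach()                            # option 1: rebind the name to the detached tensor

z = w * 2
z.detach_()                               # option 2: detach z itself, in place
```

After either option the tensor no longer carries a grad_fn, so a later backward() call will not try to walk into the graph of a previous iteration.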


@ptrblck Yep, that was exactly what I was doing wrong. Thanks a lot for pointing out the problem!

On a side note, is this how people normally write time series models? Detaching the temporary variables and writing so many register_buffer calls seems quite inelegant!

It depends on your use case. You need to register these tensors as buffers only if they should be part of the state_dict and be pushed to the specified device via .to().
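A small sketch of what registering a buffer changes, compared with a plain tensor attribute. The module WithBuffer and its attribute names are hypothetical, and the example uses a dtype conversion via .to() so it runs without a GPU; a device move follows the same rule.

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffer: tracked by the module.
        self.register_buffer("state", torch.zeros(2, 3))
        # Plain tensor attribute: not tracked by the module.
        self.plain = torch.zeros(2, 3)

m = WithBuffer()
keys = list(m.state_dict().keys())   # only the buffer appears here

m = m.to(torch.float64)              # converts parameters and buffers
buffer_dtype = m.state.dtype         # follows the module
plain_dtype = m.plain.dtype          # left as it was created
```

If the tensors are recreated for every batch anyway, neither property is needed and a plain local variable created on the right device is enough.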

Regarding the detaching: yes. You would either recreate (repackage) the hidden states, detach them, or keep them attached, depending on the workflow. Detaching tensors automatically would limit usability and rule out certain use cases.
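For comparison, here is one possible way to write the same kind of model without any buffers. It is only a sketch under assumptions: the class name StackedCells, the num_layers argument, and the choice to collect outputs in a list are not from the original post. Hidden states are passed in and returned as ordinary tensors, created directly on the requested device, and detached between batches with the repackage_hidden helper from the question.

```python
import torch
import torch.nn as nn

class StackedCells(nn.Module):
    """Hypothetical variant: stacked LSTMCells with no registered buffers."""
    def __init__(self, size=25, num_layers=4):
        super().__init__()
        self.size = size
        in_sizes = [1] + [size] * (num_layers - 1)
        self.cells = nn.ModuleList(nn.LSTMCell(s, size) for s in in_sizes)
        self.fc1 = nn.Linear(size, 1)

    def init_hidden(self, batch_size, device=None):
        # One (h, c) pair per layer, created directly on the target device.
        return [(torch.zeros(batch_size, self.size, device=device),
                 torch.zeros(batch_size, self.size, device=device))
                for _ in self.cells]

    def forward(self, x, hidden):
        # x has shape (LEN, BSIZE, 1); hidden is a sequence of (h, c) pairs.
        hidden = list(hidden)
        outputs = []
        for x_t in x:
            inp = x_t
            for k, cell in enumerate(self.cells):
                hidden[k] = cell(inp, hidden[k])
                inp = hidden[k][0]          # h of layer k feeds layer k + 1
            outputs.append(self.fc1(inp))
        # Fresh output tensor each call, so nothing carries over between batches.
        return torch.stack(outputs), hidden

def repackage_hidden(h):
    # Same helper as in the question: detach every tensor in a nested structure.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)

# Two training steps on random data, carrying the hidden state across batches.
model = StackedCells()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
hidden = model.init_hidden(batch_size=32)
losses = []
for _ in range(2):
    x = torch.randn(16, 32, 1)
    target = torch.randn(16, 32, 1)
    optimizer.zero_grad()
    hidden = repackage_hidden(hidden)
    out, hidden = model(x, hidden)
    loss = nn.functional.mse_loss(out, target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Because the output is built with torch.stack on each call, there is no preallocated tensor that could keep a reference to the previous batch's graph.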


Yeah, I thought it was a good idea to register those tensors as buffers, since otherwise they won’t be transferred to the GPU. That would mean GPU-to-CPU data transfers whenever the hidden variables are updated, which seems inefficient.