Training Stateful LSTM in Pytorch cause runtime error

class LSTMPricePredictor(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.input_dim = input_size
        self.hidden_dim = hidden_size
        self.num_layers = num_layers
        
        self.init_hidden()
        
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.hidden_to_output = nn.Linear(hidden_size, output_size)
    
    def init_hidden(self):
        self.hidden = (torch.zeros((self.num_layers, 1, self.hidden_dim), requires_grad=True), torch.zeros((self.num_layers, 1, self.hidden_dim)))
    
    def forward(self, inputs):
        x, (hx, cx) = inputs
        lstm_out, (hx, cx) = self.lstm(x, (hx, cx))
        #out = self.hidden_to_output(lstm_out.view(-1, self.hidden_dim))
        out = self.hidden_to_output(lstm_out)
        return out, (hx, cx)

class StatefullLSTMPricePredictor(LSTMPricePredictor):
    def forward(self, input_batches):
        linear_layers_results = []
        self.hidden[0].detach_()
        for i, input_batch in enumerate(input_batches):
            lstm_out, self.hidden = self.lstm(input_batch.unsqueeze(0), self.hidden)
            linear_layers_results.append(self.hidden_to_output(lstm_out[:, -1:, :]))
        return torch.cat(linear_layers_results, dim=0)

model = StatefullLSTMPricePredictor(4, HIDDEN_SIZE, 4, 1)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
for i in range(3, len(features), 3):
    criterion.zero_grad()
    out = model(features[:i])
    loss = criterion(out, labels[:i])
#     loss.backward() if i == 3 else loss.backward(retain_graph=True)
    loss.backward()
    optimizer.step()

Evaluating these code to train stateful LSTM cause following error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Can use please explain what is wrong?

1 Like

Is there any way to detect which model parameters are freed and cause such problem?

It might interest you to know that I’ve been trying to do something similar myself: Confusion regarding PyTorch LSTMs compared to Keras stateful LSTM

Although I’m not sure if just wrapping the previous hidden data in a torch.Variable ensures that stateful training works

1 Like

I think you need to detach both hiddens because the hiddens that are output from the LSTM will require grad.
I would recommend doing so when you actually store them rather than before (but probably with .detach() rather than the inplace variant).

Best regards

Thomas

2 Likes