LSTM Stateful batch size != 1

I want to train an LSTM neural network similarly to how I do it in Keras with stateful=True. The goal is to be able to carry the states over between the sequences of the same batch and between sequences of consecutive batches. This is the class I use for the LSTM module:

import torch
import torch.nn as nn


class LSTMStateful(nn.Module):
    
    def __init__(self, input_size, hidden_size, batch_size,
                 bidirectional=False, **kwargs):
        
        super().__init__()
        self._hidden_state, self._hidden_cell = (None, None)
        self._batch_size = batch_size
        self._hidden_size = hidden_size
        self._bidirectional = bidirectional
        self.lstm = nn.LSTM(input_size, hidden_size,
                            bidirectional=bidirectional, **kwargs)
        self.reset_hidden_cell()
        self.reset_hidden_state()
        
    @property
    def batch_size(self):
        return self._batch_size
        
    @property
    def bidirectional(self):
        return self._bidirectional
        
    @property
    def hidden_size(self):
        return self._hidden_size
    
    @property    
    def hidden_cell(self):
        return self._hidden_cell
    
    @property    
    def hidden_state(self):
        return self._hidden_state
    
    def reset_hidden_cell(self):
        # shape: (num_layers * num_directions, batch_size, hidden_size)
        self._hidden_cell = torch.zeros(self.lstm.num_layers * (self.bidirectional + 1),
                                        self.batch_size, self.hidden_size)

    def reset_hidden_state(self):
        self._hidden_state = torch.zeros(self.lstm.num_layers * (self.bidirectional + 1),
                                         self.batch_size, self.hidden_size)
        
    def forward(self, input_seq):
        # nn.LSTM returns (output, (h_n, c_n)); the attribute names below are
        # swapped with respect to PyTorch's convention, but since they are fed
        # back to the LSTM in the same order, the behaviour stays consistent.
        lstm_out, (self._hidden_cell, self._hidden_state) = self.lstm(
            input_seq, (self._hidden_cell, self._hidden_state))
        return lstm_out, (self._hidden_cell, self._hidden_state)
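
For reference, a minimal usage sketch of this module (the hidden_size=100 and the random input are just illustrative values, not my actual setup):

lstm_stateful = LSTMStateful(input_size=1, hidden_size=100, batch_size=250,
                             num_layers=2, batch_first=True)

x = torch.randn(250, 200, 1)   # (batch, seq_len, features) with batch_first=True
out, (h, c) = lstm_stateful(x)

print(out.shape)               # torch.Size([250, 200, 100])
print(h.shape, c.shape)        # each torch.Size([2, 250, 100]), i.e. (num_layers * num_directions, batch, hidden)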

So I create a simple two-layer stacked model to compare it to the Keras model:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = LSTMStateful(input_size=1, batch_size=BATCH_SIZE,
                                 hidden_size=HIDDEN_SIZE,
                                 num_layers=2, batch_first=True)
        
        self.linear = nn.Linear(HIDDEN_SIZE, 1)
        
    def forward(self, x):
        output, (hidden_cell, hidden_state) = self.lstm(x)
        output = self.linear(output[:, -1:, :])  # return_sequences=False in Keras
        return output

So I train the model. I use 1750 sequences of length 200, split into batches of 250, i.e. each batch has shape (250, 200, 1) with batch_first=True.
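
For context, a minimal sketch of how X_train / Y_train with those shapes could be built (the tensors here are random placeholders, not the actual data):

BATCH_SIZE = 250

data_x = torch.randn(1750, 200, 1)   # placeholder inputs
data_y = torch.randn(1750, 200, 1)   # placeholder targets

X_train = torch.split(data_x, BATCH_SIZE)  # 7 batches of shape (250, 200, 1)
Y_train = torch.split(data_y, BATCH_SIZE)

The model and training loop then look like this: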

model = MyModel()
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
torch.autograd.set_detect_anomaly(True)

epochs = 1000
train_loss = []
eval_loss = []
for i in range(epochs):
    
    model.zero_grad()
    model.train()
    for x_train, y_train in zip(X_train, Y_train):
        # x_train shape: (250, 200, 1)
        # y_train shape: (250, 200, 1)
        y_pred = model(x_train)
        single_loss = loss_function(y_pred, y_train)
        loss_train_value = single_loss.item()
        train_loss.append(loss_train_value)
        
    single_loss.backward()
    optimizer.step()
    
    model.lstm.reset_hidden_cell()
    model.lstm.reset_hidden_state()
    
    
    model = model.eval()
    with torch.no_grad():
        for x_eval, y_eval in zip(X_eval, Y_eval):
            y_pred_eval = model(x_eval)
            single_loss_eval = loss_function(y_pred_eval, y_eval)
            loss_eval_value=single_loss_eval.item()
            eval_loss.append(loss_eval_value)


    print(f'epoch: {i:3} loss: {loss_train_value:10.8f}  eval loss: {loss_eval_value:10.8f}')

This takes around 30 seconds per epoch, while the Keras model (which I consider equivalent) takes about 5 seconds per epoch. In addition, the validation loss is much higher than the one I obtain in Keras (both Keras and PyTorch use the same training and validation data). Leaving aside the difference in weight and bias initialization, it is evident that I am doing something wrong. Could you help me?

You are currently calculating the gradients using only the last batch, since single_loss.backward() is called outside the data loop.
Is this your use case or would you like to call single_loss.backward() and optimizer.step() inside the data loop?
If that’s the case, you might need to call detach() on the states to avoid trying to backpropagate multiple times through the same variables.


You are right, I want to update the weights on each batch. I'm not sure where I should call detach().
I have tried the following code:

for i in range(epochs):
    
    for x_train, y_train in zip(X_train, Y_train):
        model.train()
        model.zero_grad()
        y_pred = model(x_train)
        single_loss = loss_function(y_pred, y_train)
        loss_train_value=single_loss.item()
        train_loss.append(loss_train_value)
        
        single_loss.backward()  
        optimizer.step()

    model.lstm.reset_hidden_cell()
    model.lstm.reset_hidden_state()


I receive the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-22-a544b1c4beff> in <module>
     20         train_loss.append(loss_train_value)
     21 
---> 22         single_loss.backward()
     23         optimizer.step()
     24 

~/anaconda3/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

However, I do not receive this error if I call

model.lstm.reset_hidden_cell()
model.lstm.reset_hidden_state()

after each batch instead of only once per epoch:

for i in range(epochs):
    
    for x_train, y_train in zip(X_train, Y_train):
        model.train()
        model.zero_grad()
        y_pred = model(x_train)
        single_loss = loss_function(y_pred, y_train)
        loss_train_value=single_loss.item()
        train_loss.append(loss_train_value)
        
        single_loss.backward()  
        optimizer.step()

        model.lstm.reset_hidden_cell()
        model.lstm.reset_hidden_state()

However, this means that the states are not carried over between batches, so this is not the solution.

I have also tried specifying retain_graph=True, i.e. single_loss.backward(retain_graph=True), for all batches except the last one, but that doesn't seem to work either.

Instead of resetting the states, try to detach them via:

model.lstm._hidden_state = model.lstm._hidden_state.detach()
model.lstm._hidden_cell = model.lstm._hidden_cell.detach()

to keep the state values, but detach them from the previous computation graph.
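
Putting both suggestions together, a minimal sketch of a per-batch training loop with detached states could look like this (it reuses the names from the code above; the exact placement of zero_grad() and the per-epoch reset are just one reasonable option):

for i in range(epochs):
    model.train()
    for x_train, y_train in zip(X_train, Y_train):
        model.zero_grad()

        # keep the current state values, but cut the graph from previous batches
        model.lstm._hidden_state = model.lstm._hidden_state.detach()
        model.lstm._hidden_cell = model.lstm._hidden_cell.detach()

        y_pred = model(x_train)
        single_loss = loss_function(y_pred, y_train)
        train_loss.append(single_loss.item())

        single_loss.backward()   # backpropagates only through the current batch
        optimizer.step()

    # start each epoch from zeroed states
    model.lstm.reset_hidden_cell()
    model.lstm.reset_hidden_state()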
