When to detach the hidden state for stateful LSTMs?

I’d like to build a stateful LSTM but I receive the runtime error “Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.”

This happens because PyTorch frees the intermediate buffers of the computational graph after the first backward() call to save memory, so it cannot backpropagate through that graph again. The topic is covered in these two threads, and albanD describes what is happening clearly.
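
For example, I can reproduce the error with a minimal loop like this (the layer sizes and the dummy loss are made up purely for illustration), where the hidden state from one iteration is fed into the next without being detached:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8)
hidden = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))

for step in range(2):
    x = torch.randn(5, 1, 4)       # (series_length, batch_size, num_features)
    out, hidden = lstm(x, hidden)  # hidden still references the previous graph
    loss = out[-1].sum()           # dummy loss just so backward() can be called
    loss.backward()                # second iteration raises the error above

The second backward() tries to go back through the carried-over hidden state into the first iteration's graph, whose buffers have already been freed.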

The two suggested solutions are retaining the computational graph (which I don’t want to do) and detaching the hidden state between batches.
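
As far as I understand it, the second option would sit in the training loop, roughly like this (the model, criterion, optimiser and loader names are placeholders I’ve made up):

for x_batch, y_batch in loader:
    # Detach the hidden state so it no longer references the graph
    # built during the previous batch (is this the right place?)
    model.hidden = tuple(h.detach() for h in model.hidden)

    optimiser.zero_grad()
    y_pred = model(x_batch)
    loss = criterion(y_pred, y_batch)

    # The alternative would be loss.backward(retain_graph=True),
    # which is what I want to avoid
    loss.backward()
    optimiser.step()

For reference, this is my model: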

import torch
import torch.nn as nn


class stateful_LSTM(nn.Module):
    """A Long Short-Term Memory network model."""
        
    def __init__(self, num_features, hidden_dim, output_dim,
                 batch_size, series_length, device, 
                 dropout=0.1, num_layers=2, debug=True):
        
        super(stateful_LSTM, self).__init__()
        
        # Number of features
        self.num_features = num_features
        
        # Hidden dimensions
        self.hidden_dim = hidden_dim

        # Number of hidden layers
        self.num_layers = num_layers
        
        # The output dimensions
        self.output_dim = output_dim
        
        # Batch Size
        self.batch_size = batch_size
        
        # Length of sequence
        self.series_length = series_length
        
        # Dropout Probability
        self.dropout = dropout
        
        # CPU or GPU
        self.device = device
        
        # Define the LSTM layer
        self.lstm = nn.LSTM(
            input_size=self.num_features,
            hidden_size=self.hidden_dim,
            dropout=self.dropout,
            num_layers=self.num_layers)

        # Fully Connected Layer
        self.fc1 = nn.Linear(in_features=self.hidden_dim, 
                             out_features=self.hidden_dim)
        
        # Activation function
        self.act = nn.ReLU()
        
        # Output layer
        self.out = nn.Linear(in_features=self.hidden_dim, 
                             out_features=self.output_dim)
        
        self.hidden = self.init_hidden()
        
        
    def init_hidden(self):
        """Initialised the hidden state to be zeros"""
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim).to(self.device),
                torch.zeros(self.num_layers, self.batch_size, self.hidden_dim).to(self.device))
    
    
    def forward(self, x):
        """Forward pass through the neural network"""
        # Infer the batch size and series length from the input
        batch_size = x.size()[0]
        series_length = x.size()[1]
        
        # Reshape to (series_length, batch_size, num_features) for the LSTM
        x = x.contiguous().view(series_length, batch_size, -1)
        
        # Pass through the LSTM layer
        x, self.hidden = self.lstm(x, self.hidden)
        
        # Keep only the output of the last time step
        x = x[-1]
        
        # Fully connected hidden layer
        x = self.act(self.fc1(x))
        
        return self.out(x)

I am slightly confused as to when this detach is meant to occur. I’ve tried detaching it straight after the self.lstm layer but that doesn’t work. Could someone please explain when (and why then) you should detach the hidden state?
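
To make the question concrete: is the detach supposed to happen inside forward, immediately after the LSTM call, roughly like this (maybe I just got the detach itself wrong when I tried it), or outside in the training loop between batches as in the sketch further up?

        # Pass through the LSTM layer
        x, self.hidden = self.lstm(x, self.hidden)

        # Detach here, inside forward, straight after the LSTM?
        self.hidden = tuple(h.detach() for h in self.hidden)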

Thanks!