I’d like to build a stateful LSTM but I receive the runtime error “Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.”
This happens because the intermediate buffers needed for backpropagation are freed after the first backward pass to save memory, so a second backward through the same graph fails. This topic is covered in these two threads, and albanD describes what is happening clearly.
The two solutions are retaining the computational graph with `retain_graph=True` (which I don’t want to do, since it keeps growing memory use) and detaching the hidden state between batches.
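For context, here is a minimal sketch of the detach-between-batches approach with a plain `nn.LSTM` (all sizes and names here are illustrative, not taken from my model): the hidden state values are carried over, but `detach()` is called before each new forward pass so the autograd graph from the previous batch is not kept alive.

```python
import torch
import torch.nn as nn

# Tiny illustrative setup: 2-layer LSTM, batch of 3, 4 input features
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=2)
fc = nn.Linear(8, 1)
opt = torch.optim.SGD(list(lstm.parameters()) + list(fc.parameters()), lr=0.01)
loss_fn = nn.MSELoss()

# Persistent hidden state carried across batches (the "stateful" part)
hidden = (torch.zeros(2, 3, 8), torch.zeros(2, 3, 8))

for step in range(2):
    x = torch.randn(5, 3, 4)  # (seq_len, batch, features)
    y = torch.randn(3, 1)

    # Detach BEFORE the forward pass: keep the hidden state's values,
    # but cut the autograd link back to the previous batch's graph.
    hidden = tuple(h.detach() for h in hidden)

    out, hidden = lstm(x, hidden)
    loss = loss_fn(fc(out[-1]), y)

    opt.zero_grad()
    loss.backward()  # no retain_graph needed: each batch has its own graph
    opt.step()
```

Without the `detach()` line, the second iteration's `backward()` would try to traverse the first iteration's (already freed) graph and raise exactly the error above.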
```python
import torch
import torch.nn as nn


class stateful_LSTM(nn.Module):
    """A Long Short-Term Memory network model."""

    def __init__(self, num_features, hidden_dim, output_dim, batch_size,
                 series_length, device, dropout=0.1, num_layers=2, debug=True):
        super(stateful_LSTM, self).__init__()
        self.num_features = num_features    # Number of input features
        self.hidden_dim = hidden_dim        # Hidden dimensions
        self.num_layers = num_layers        # Number of stacked LSTM layers
        self.output_dim = output_dim        # Output dimensions
        self.batch_size = batch_size        # Batch size
        self.series_length = series_length  # Length of the input sequence
        self.dropout = dropout              # Dropout probability
        self.device = device                # CPU or GPU

        # Define the LSTM layer
        self.lstm = nn.LSTM(input_size=self.num_features,
                            hidden_size=self.hidden_dim,
                            dropout=self.dropout,
                            num_layers=self.num_layers)

        # Fully connected layer
        self.fc1 = nn.Linear(in_features=self.hidden_dim,
                             out_features=self.hidden_dim)
        # Activation function
        self.act = nn.ReLU()
        # Output layer
        self.out = nn.Linear(in_features=self.hidden_dim,
                             out_features=self.output_dim)

        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Initialise the hidden state to zeros."""
        return (torch.zeros(self.num_layers, self.batch_size,
                            self.hidden_dim).to(self.device),
                torch.zeros(self.num_layers, self.batch_size,
                            self.hidden_dim).to(self.device))

    def forward(self, x):
        """Forward pass through the network."""
        # Adjust to a variable batch size
        batch_size = x.size(0)
        series_length = x.size(1)

        # Keep the dimensions (seq_len, batch, features) regardless of batch size
        x = x.contiguous().view(series_length, batch_size, -1)

        # Pass through the LSTM layer
        x, self.hidden = self.lstm(x, self.hidden)

        # Take the output of the last time step
        x = x[-1]

        # Fully connected hidden layer
        x = self.act(self.fc1(x))
        return self.out(x)
```
I am slightly confused as to when this detach is meant to occur. I have tried detaching it straight after the `self.lstm` layer, but that doesn’t work. Could someone please explain when (and why at that point) you should detach the hidden state?