I am currently building a combined CNN / LSTM model, where the CNN builds an input feature vector for each frame in a video sequence. Then, I loop over these input feature vectors, and feed them into an LSTMCell.
After each LSTM step, I extract the last 4 elements of the hidden state, save them in a list, and use them to calculate a custom gradient, which is then fed back to the hidden state after T frames.

The problem: after the first sequence, the hidden state doesn't seem to get freed from the graph, as subsequent calls to `.backward(gradient)` don't appear to have any effect.
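To make the setup concrete, here is a minimal, self-contained sketch of the loop I mean (all sizes and inputs below are made up; my real features come from the CNN):

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes, just for illustration.
feat_size, hidden_size, T = 8, 16, 5
lstm = nn.LSTMCell(feat_size, hidden_size)

# Stand-in for the per-frame CNN features of one sequence.
lstm_input = torch.randn(T, feat_size)

hidden = (torch.zeros(1, hidden_size), torch.zeros(1, hidden_size))
mus = []
for t in range(T):
    hidden = lstm(lstm_input[t].unsqueeze(0), hidden)  # hidden is the tuple (h, c)
    mus.append(hidden[0][0, -4:])  # last 4 elements of the hidden state

print(len(mus), mus[0].shape)  # T entries, each of shape (4,)
```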
```python
lstm = nn.LSTMCell(4096 + num_coords, hidden_size)
hidden = (autograd.Variable(torch.zeros(1, hidden_size)),
          autograd.Variable(torch.zeros(1, hidden_size)))

if cuda_:
    print('CUDA enabled.')
    net.cuda()
    lstm.cuda()

optimizer = optim.Adam(list(lstm.parameters()), lr=args.learning_rate)

def detach(states):
    return [state.detach() for state in states]

if __name__ == '__main__':
    for u in range(args.epochs):
        for i, video in tqdm(enumerate(dataloader)):
            for j in range(frames_per_video):
```
At this point, we build the input feature vectors, later denoted `lstm_input`, using a pretrained CNN.
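For reference, a toy stand-in for this step (a tiny dummy CNN with made-up sizes instead of the real pretrained one):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pretrained feature extractor.
cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # -> (num_frames, 8): one feature vector per frame
)

frames = torch.randn(5, 3, 32, 32)  # 5 frames of a toy video
with torch.no_grad():               # features are inputs here, not trained
    lstm_input = cnn(frames)
print(lstm_input.shape)             # (5, 8)
```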
```python
lstm_input = build_feature_vector_using_a_CNN()
hidden = (autograd.Variable(torch.zeros(1, hidden_size)),
          autograd.Variable(torch.zeros(1, hidden_size)))
optimizer.zero_grad()
hidden = detach(hidden)

for t in range(len(lstm_input)):
    hidden = lstm(lstm_input[t].unsqueeze(0), hidden)  # hidden is the tuple (h, c)
    mu = hidden[0][0, -4:]  # extract last 4 elements of the hidden state
    # save these mu in a list

# CALCULATE CUSTOM_GRAD EXTERNALLY
if j == 0:
    print('Gradient: ', CUSTOM_GRAD)  # THIS GRADIENT IS LARGE

hidden[0][0, -num_coords:].backward(CUSTOM_GRAD)

for f in lstm.parameters():
    print('grad sum is')
    print(np.sum(f.grad.cpu().data.numpy()))  # THIS GRADIENT IS DECLINING

optimizer.step()
optimizer.zero_grad()
```
I can’t seem to solve this. My guess is that some element of my “external computation” prevents the hidden state from being detached. Why, though, does detaching have no effect?
```python
hidden = detach(hidden)
```
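For context, here is my understanding of what `detach()` is supposed to do, as a tiny self-contained example (toy tensors, nothing from my actual model): it returns a new tensor with no history, so a backward pass through a graph built on the detached tensor should not reach anything before the cut.

```python
import torch

# Toy stand-ins for the hidden state and the LSTM weights.
h = torch.randn(1, 3, requires_grad=True)
w = torch.randn(3, 3, requires_grad=True)

out1 = h @ w          # "first sequence": graph running through h and w
h2 = out1.detach()    # new tensor with no history; out1 itself is unchanged

out2 = (h2 @ w).sum() # "second sequence" built on the detached state
out2.backward()

print(h.grad)         # None: backward did not cross the detach cut
print(w.grad is None) # False: w still gets a gradient from the new graph
```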
Thanks for any help.