Shared gradient memory during backpropagation through time (BPTT)

I’m trying to combine an LSTM (trained from scratch) and a CNN (pre-trained) as follows:

Input -> LSTM -> CNN -> Output (t=1)
           |
Input -> LSTM -> CNN -> Output (t=2)
           |
Input -> LSTM -> CNN -> Output (t=3)
           |

The recurrent connection exists only in the LSTM, not in the CNN.

I use a pre-trained CNN, so I don’t need to train the CNN itself.
All I need is the gradient from the CNN to train the LSTM.
Because the CNN requires a lot of memory, I want to share the CNN’s gradient memory across time steps during BPTT.

For example, the gradient from the CNN would be computed and saved for training the LSTM at each time step, but the internal gradient memory (or buffers?) of the CNN would be reused for the next step.

If I follow the conventional RNN training code (forward for t = 1, 2, 3, …, T and backward for t = T, …, 3, 2, 1), I think PyTorch’s dynamic graph will allocate the CNN’s memory separately for each time step.
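
For reference, here is roughly the conventional loop I mean; the small modules and shapes below are just placeholders, not my real models:

```python
import torch
import torch.nn as nn

# Small placeholders standing in for the real models (the real CNN is much bigger).
T, batch, feat, n_classes = 3, 4, 16, 10
lstm = nn.LSTM(feat, feat)                 # trained from scratch
cnn = nn.Linear(feat, n_classes)           # stand-in for the pre-trained CNN
criterion = nn.CrossEntropyLoss()

x = torch.randn(T, batch, feat)            # sequence-first dummy input
y = torch.randint(0, n_classes, (T, batch))

h = None
total_loss = 0.0
for t in range(T):
    out_t, h = lstm(x[t:t + 1], h)         # the recurrence lives only in the LSTM
    logits = cnn(out_t.squeeze(0))         # CNN applied independently at each time step
    total_loss = total_loss + criterion(logits, y[t])

# A single backward over the whole sequence: the graph, including every
# time step's CNN activations, is kept in memory until this call.
total_loss.backward()
```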

How can I handle this problem?
I would appreciate any answer.


Hello,

if the loss is per time step, you could:

  • tell the cnn you won’t need parameter gradients: for p in cnn.parameters(): p.requires_grad = False,
  • run the LSTM on the input, capture output in one long array, say, lstm_out. Let’s pretend it is sequence-first.
  • apply the CNN to each timestep-slice wrapped in a new Variable, say step_in = Variable(lstm_out.data[i], requires_grad=True), so out_step = cnn(step_in),
  • compute the loss per step
    loss_step = loss_function(out_step, target),
  • backprop through the cnn with loss_step.backward(),
  • append the gradient for the timestep to a list, say, lstm_out_grad.append(step_in.grad) (the gradient ends up on the wrapped input, not on out_step),
  • backprop through the LSTM with lstm_out.backward(torch.stack(lstm_out_grad)).

This should give you the gradients in the LSTM, except for any bugs that come with code that was only typed and not tested. :)

Here is a minimal, untested sketch of how to manually break backprop this way. The small modules are placeholders for your real CNN and LSTM, and lstm_out[i].detach().requires_grad_(True) is the current-API equivalent of the Variable wrapper above:
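
```python
import torch
import torch.nn as nn

# Tiny placeholder modules -- swap in your real pre-trained CNN and your LSTM.
seq_len, batch, feat, n_classes = 3, 4, 16, 10
lstm = nn.LSTM(feat, feat)                        # trained from scratch
cnn = nn.Linear(feat, n_classes)                  # pretend this is the big pre-trained CNN
loss_function = nn.CrossEntropyLoss()

x = torch.randn(seq_len, batch, feat)             # dummy input, sequence-first
targets = torch.randint(0, n_classes, (seq_len, batch))

# No parameter gradients for the frozen CNN.
for p in cnn.parameters():
    p.requires_grad = False

# Run the LSTM over the whole sequence in one go.
lstm_out, _ = lstm(x)                             # (seq_len, batch, feat)

# Per time step: detach, push through the CNN, backprop the step loss,
# and collect the gradient w.r.t. the detached LSTM output.
lstm_out_grad = []
for i in range(seq_len):
    step_in = lstm_out[i].detach().requires_grad_(True)  # ~ Variable(lstm_out.data[i], requires_grad=True)
    out_step = cnn(step_in)
    loss_step = loss_function(out_step, targets[i])
    loss_step.backward()                          # the CNN part of the graph is freed right here
    lstm_out_grad.append(step_in.grad)

# Backprop the collected gradients through the LSTM in one call.
lstm_out.backward(torch.stack(lstm_out_grad))
```

Because each loss_step.backward() frees that step’s CNN graph immediately, only a single time step’s CNN activations are alive at any moment, which is where the memory saving comes from.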

Best regards

Thomas


Dear Thomas,

That seems like a really cool solution!!
I will try it as you suggested.

Thank you very much for your help.

Best,

MInju Jung

Did you manage to get the gradient at each time step?