Shared gradient memory during backpropagation through time (BPTT)

Hello,

If the loss is computed per time-step, you could (a sketch putting these steps together follows the list):

  • tell the cnn you won’t need parameter gradients: for p in cnn.parameters(): p.requires_grad = False,
  • run the LSTM on the input and capture the output in one long tensor, say, lstm_out. Let’s pretend it is sequence-first.
  • wrap each timestep-slice in a new Variable (a fresh leaf), lstm_out_step = Variable(lstm_out.data[i], requires_grad=True), and apply the CNN to it: out_step = cnn(lstm_out_step),
  • compute the loss per step
    loss_step = loss_function(out_step, target),
  • backprop through the cnn with loss_step.backward(),
  • append the gradient that arrived at the timestep input to a list, say, lstm_out_grad.append(lstm_out_step.grad),
  • backprop through the LSTM with lstm_out.backward(torch.stack(lstm_out_grad)).
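
Putting the steps above together, a rough sketch in current PyTorch could look like the following (detach() plays the role of wrapping in a new Variable; the sizes, the cnn stand-in, loss_function, inputs and targets are all made-up placeholders for whatever you actually use):

    import torch
    import torch.nn as nn

    # Placeholders standing in for the real cnn, lstm, loss_function and data.
    lstm = nn.LSTM(input_size=32, hidden_size=64)   # sequence-first (batch_first=False)
    cnn = nn.Sequential(nn.Linear(64, 10))          # stand-in for the per-step CNN head
    loss_function = nn.CrossEntropyLoss()
    inputs = torch.randn(20, 4, 32)                 # (seq_len, batch, features)
    targets = torch.randint(0, 10, (20, 4))         # one target per time-step

    # Step 1: no parameter gradients for the CNN.
    for p in cnn.parameters():
        p.requires_grad = False

    # Step 2: run the LSTM once, keeping the whole output sequence.
    lstm_out, _ = lstm(inputs)                      # (seq_len, batch, hidden)

    lstm_out_grad = []
    for i in range(lstm_out.size(0)):
        # Step 3: cut the graph, making a fresh leaf per time-step.
        lstm_out_step = lstm_out[i].detach().requires_grad_()
        out_step = cnn(lstm_out_step)
        # Steps 4-5: per-step loss, backprop through the CNN only.
        loss_step = loss_function(out_step, targets[i])
        loss_step.backward()
        # Step 6: collect the gradient that arrived at the time-step input.
        lstm_out_grad.append(lstm_out_step.grad)

    # Step 7: backprop through the LSTM in one go.
    lstm_out.backward(torch.stack(lstm_out_grad))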

This should give you the gradients in the LSTM, except for any bugs that come with code that has only been typed and not tested. :)

Here is a minimal snippet showing how to manually break up backprop:
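
(A toy example: sin and a squared sum stand in for the two stages, and detach() again takes the place of re-wrapping in a new Variable.)

    import torch

    x = torch.randn(5, requires_grad=True)
    y = x.sin()                          # stage 1

    y_cut = y.detach().requires_grad_()  # cut the graph: y_cut is a fresh leaf
    z = (y_cut ** 2).sum()               # stage 2

    z.backward()                         # backprop through stage 2 only, fills y_cut.grad
    y.backward(y_cut.grad)               # resume backprop through stage 1

    # chain rule check: d/dx sum(sin(x)^2) = 2 sin(x) cos(x)
    print(torch.allclose(x.grad, 2 * x.sin() * x.cos()))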

Best regards

Thomas
