Hello,
if the loss is per time-step, you could

- tell the CNN you won't need parameter gradients:
  for p in cnn.parameters(): p.requires_grad = False
- run the LSTM on the input and capture its output in one long tensor, say lstm_out (let's pretend it is sequence-first),
- apply the CNN to each timestep slice, wrapped in a new Variable so the graph is cut there:
  lstm_out_step = Variable(lstm_out.data[i], requires_grad=True)
  out_step = cnn(lstm_out_step)
- compute the loss per step:
  loss_step = loss_function(out_step, target)
- backprop through the CNN with
  loss_step.backward()
- append the gradient w.r.t. the timestep input (note: not out_step.grad, which would stay None because out_step is not a leaf) to a list:
  lstm_out_grad.append(lstm_out_step.grad)
- backprop through the LSTM with
  lstm_out.backward(torch.stack(lstm_out_grad))
This should give you the gradients in the LSTM, except for any bugs that come with code that was only typed and not tested.
Here is a minimal snippet showing how to manually break backprop.
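A runnable sketch of the steps above, using today's detach()/requires_grad_() in place of the old Variable wrapper; the sizes, the nn.Linear stand-in for the CNN, and the MSE loss are made up for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: all sizes are made up for illustration.
seq_len, batch, feat, hidden = 4, 2, 8, 8
lstm = nn.LSTM(input_size=feat, hidden_size=hidden)  # sequence-first by default
cnn = nn.Linear(hidden, 3)                           # stand-in for the per-step CNN
loss_function = nn.MSELoss()

# 1) Tell the CNN you won't need parameter gradients.
for p in cnn.parameters():
    p.requires_grad = False

x = torch.randn(seq_len, batch, feat)
target = torch.randn(batch, 3)

# 2) Run the LSTM once over the whole sequence.
lstm_out, _ = lstm(x)                                # (seq_len, batch, hidden)

lstm_out_grad = []
for i in range(seq_len):
    # 3) Detach the timestep slice into a new leaf -> the graph is cut here.
    lstm_out_step = lstm_out[i].detach().requires_grad_(True)
    out_step = cnn(lstm_out_step)
    # 4) Per-step loss, backprop through the CNN only.
    loss_step = loss_function(out_step, target)
    loss_step.backward()
    # 5) Collect the gradient w.r.t. the timestep slice.
    lstm_out_grad.append(lstm_out_step.grad)

# 6) Backprop through the LSTM, feeding in the collected per-step gradients.
lstm_out.backward(torch.stack(lstm_out_grad))

print(lstm.weight_ih_l0.grad is not None)  # the LSTM parameters got gradients
```

The key point is that detaching each slice makes the per-step backward cheap (it never touches the LSTM graph), and the stacked slice gradients are exactly the argument `lstm_out.backward` needs to resume backprop through the recurrence.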
Best regards
Thomas