How does backward propagation work in a nested neural network

Hi, everyone. I am working on a dynamic programming problem in which a nested neural network is used. To make my question easier to follow, simply assume the loss is

L = G(k′) - H(k′′)

where G and H are two functions whose details do not matter here, and the variables k′ and k′′ come from the same neural network NNs such that

k′ = NNs(k), k′′ = NNs(k′)

so k is the original input. In other words, to obtain the output k′′, the same neural network is applied twice (nested once):

k → NNs → k′ → NNs → k′′
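For concreteness, here is a minimal PyTorch sketch of this setup. The small MLP used for NNs and the placeholder G and H are my own stand-ins, not part of the actual problem:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a small MLP for NNs and simple differentiable
# placeholders for G and H (the real ones are not specified here).
NNs = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 4))
G = lambda x: x.pow(2).sum()
H = lambda x: x.sum()

k = torch.randn(4)      # original input k
k1 = NNs(k)             # k'  = NNs(k)
k2 = NNs(k1)            # k'' = NNs(k')
loss = G(k1) - H(k2)    # L = G(k') - H(k'')

# A single backward call propagates through both applications of NNs.
loss.backward()
```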

My questions are simple:

(1) How does backward propagation work in this nested setting (one backpropagation pass per transition, I guess)? Does it work the same as in the regular case?
(2) Could this nested NNs setup slow down the computation?
(3) What could go wrong under this nested NNs setup?

  1. Autograd creates a computation graph during the forward pass and uses it to backpropagate during the backward call.
  2. Yes, enlarging the computation graph adds more operations to both the forward and backward passes and increases memory usage.
  3. Depending on how the recursion is used, you might run into issues where stale forward activations are used during a backward pass (see the sketch after this list). However, if PyTorch doesn't raise any errors, your code should be fine.
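To illustrate points 1 and 3 with a hypothetical sketch (again using my own stand-ins for NNs, G, and H): one backward call covers the whole nested graph, and updating the shared parameters in place between the forward pass and a later backward call is one way to end up with stale saved tensors:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins, as before.
NNs = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 4))
opt = torch.optim.SGD(NNs.parameters(), lr=0.1)

k = torch.randn(4)
k1 = NNs(k)
k2 = NNs(k1)
loss = k1.pow(2).sum() - k2.sum()  # stands in for G(k') - H(k'')

# Point 1: one backward call handles both applications of NNs; each shared
# parameter ends up with a single .grad that sums both contributions.
loss.backward(retain_graph=True)
print(NNs[0].weight.grad.shape)    # torch.Size([16, 4])

# Point 3: the optimizer step updates the weights in place, so the tensors
# saved during the forward pass above no longer match the current parameters.
opt.step()
try:
    loss.backward()                # tries to reuse stale saved tensors
except RuntimeError as err:
    print(err)                     # PyTorch flags the in-place modification
```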

Thanks for the heroic answer. My code produces a NaN loss when I increase the dimensionality of the NNs output, lol. If I may ask one more question: what other potential challenges are there besides the stale forward activations issue?