Hi all! I'm new to PyTorch. I want to implement a special kind of deep neural network, but I haven't found any PyTorch examples that show how to do it. The network has the following properties:

Each layer has its own loss function, and the parameters and hidden representation of that layer are learned by minimizing this loss. Note that there is no closed-form expression for the hidden representation, so it has to be inferred with an iterative optimization process.

The optimization process for one layer only affects the parameters and representation of that layer, not those of any other layer.

For the i-th layer, the update of its parameters W^{i} and representation y^{i} depends on y^{i-1} and y^{i+1}.
I don't know how to build the computational graph for such a model. Any comments and suggestions would be appreciated, and related code examples or links would be especially helpful. Many thanks.
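To make the question concrete, here is a rough sketch of the structure I have in mind. It is only a guess at how this might be set up, not a working solution. The names `LocalLayer`, `local_loss`, `train_step`, `bottom` and `top` are made up for illustration, and the squared-error loss is a placeholder assumption, not my real loss. The idea is to detach y^{i-1} and y^{i+1} so gradients stay inside layer i, to treat y^{i} as a leaf tensor that is optimized in an inner loop, and to give each layer its own optimizer.

```python
import torch
import torch.nn as nn


class LocalLayer(nn.Module):
    """Hypothetical layer i: owns its parameters W^i and a placeholder local loss."""

    def __init__(self, d_prev, d_cur, d_next=None):
        super().__init__()
        self.bottom = nn.Linear(d_prev, d_cur)  # ties y^{i-1} to y^i
        # ties y^i to y^{i+1}; the last layer has no layer above it
        self.top = nn.Linear(d_cur, d_next) if d_next is not None else None

    def local_loss(self, y_prev, y_cur, y_next=None):
        # Placeholder loss (an assumption): squared errors against both neighbours.
        loss = (y_cur - torch.tanh(self.bottom(y_prev))).pow(2).mean()
        if self.top is not None and y_next is not None:
            loss = loss + (y_next - self.top(y_cur)).pow(2).mean()
        return loss


def train_step(layers, optimizers, ys, n_infer=5, lr_y=0.1):
    """ys[0] is the input; ys[i] is the current representation of layer i."""
    losses = []
    for i, (layer, opt) in enumerate(zip(layers, optimizers), start=1):
        # Detach the neighbours so nothing propagates into other layers.
        y_prev = ys[i - 1].detach()
        y_next = ys[i + 1].detach() if i + 1 < len(ys) else None

        # Iterative inference of y^i: it is a leaf tensor with its own optimizer.
        y_cur = ys[i].detach().clone().requires_grad_(True)
        y_opt = torch.optim.SGD([y_cur], lr=lr_y)
        for _ in range(n_infer):
            y_opt.zero_grad()
            layer.local_loss(y_prev, y_cur, y_next).backward()
            y_opt.step()

        # Parameter update for W^i with the inferred y^i held fixed.
        opt.zero_grad()  # also clears gradients accumulated during inference
        loss = layer.local_loss(y_prev, y_cur.detach(), y_next)
        loss.backward()
        opt.step()

        ys[i] = y_cur.detach()
        losses.append(loss.item())
    return losses


# Toy usage with made-up sizes: input dim 8, three layers of size 6, 4, 3.
torch.manual_seed(0)
dims = [8, 6, 4, 3]
layers = [
    LocalLayer(dims[i - 1], dims[i], dims[i + 1] if i + 1 < len(dims) else None)
    for i in range(1, len(dims))
]
optimizers = [torch.optim.Adam(layer.parameters(), lr=1e-2) for layer in layers]
x = torch.randn(16, dims[0])
ys = [x] + [torch.zeros(16, d) for d in dims[1:]]
for _ in range(3):
    losses = train_step(layers, optimizers, ys)
```

Here the layers are visited one after another, so layer i sees the already-updated y^{i-1} and the not-yet-updated y^{i+1}. I am not sure whether that ordering is appropriate, or whether all layers should be updated from the same snapshot of the representations.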