They encode the sequence of images into a latent state z_0 = (p_0, q_0), where p_0 and q_0 are both k-dimensional vectors corresponding to momentum and position at time t=0. Another network then maps z_0 to a scalar value h. From that, they compute z_1 = (p_1, q_1), where p_1 = p_0 - dh/dq_0 and q_1 = q_0 + dh/dp_0. Finally, z_1 is fed to the decoder network, and the reconstruction loss is used to train the whole system end to end.
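Written out, the update described above amounts to one explicit Euler step of Hamilton's equations with unit step size. The partial-derivative notation here is an assumption: it reads dh/dq_0 and dh/dp_0 as the gradients of the scalar h with respect to q_0 and p_0.

```latex
\begin{aligned}
h   &= H(p_0, q_0), \\
p_1 &= p_0 - \frac{\partial h}{\partial q_0}, \\
q_1 &= q_0 + \frac{\partial h}{\partial p_0}.
\end{aligned}
```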
The issue is the following: how can I build a graph in PyTorch that uses the gradients of one part of the network (here, dh/dp_0 and dh/dq_0) to compute values (p_1 and q_1) that are then used by the rest of the network? Does PyTorch allow something like this?
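A minimal sketch of the computation being asked about, assuming `torch.autograd.grad` with `create_graph=True` is acceptable for this use. The `encoder`, `hamiltonian`, and `decoder` modules, the layer sizes, and the flat 16-dimensional input are hypothetical stand-ins, not the networks from the paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
k = 4  # latent dimension of p and q (hypothetical value)

# Hypothetical stand-ins for the real networks.
encoder = nn.Linear(16, 2 * k)  # input -> z_0 = (p_0, q_0)
hamiltonian = nn.Sequential(    # z -> scalar h per sample
    nn.Linear(2 * k, 32), nn.Tanh(), nn.Linear(32, 1)
)
decoder = nn.Linear(2 * k, 16)  # z_1 -> reconstruction


def step(x):
    p0, q0 = encoder(x).chunk(2, dim=-1)
    # Summing over the batch gives per-sample gradients,
    # because the samples do not interact inside the network.
    h = hamiltonian(torch.cat([p0, q0], dim=-1)).sum()
    # create_graph=True records the gradient computation itself,
    # so the loss can later backpropagate through dh/dp_0 and dh/dq_0.
    dh_dp, dh_dq = torch.autograd.grad(h, (p0, q0), create_graph=True)
    p1 = p0 - dh_dq
    q1 = q0 + dh_dp
    return decoder(torch.cat([p1, q1], dim=-1))


x = torch.randn(8, 16)
loss = nn.functional.mse_loss(step(x), x)
loss.backward()  # second-order backward through the Hamiltonian network
```

Because h enters the loss only through its derivatives, the bias of the last `hamiltonian` layer should receive no gradient in this sketch; the remaining parameters are trained through the double backward pass.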
You can compute the gradients of a subgraph during the forward pass. If you do this inside a custom autograd.Function, you can plug the precomputed gradients into the outer graph. I'm not sure this will help in your case, but here is an illustrative snippet: