I am trying to reproduce Hamiltonian Generative Networks.
They encode the sequence of images into a latent space
z_0 = (p_0, q_0), where p_0 and
q_0 are both k-dimensional vectors, corresponding to the momentum and position at time t_0.
z_0 is then mapped to a scalar value
h (the Hamiltonian) by another network. Then they compute
z_1 = (p_1, q_1) where
p_1 = p_0 - dh/dq_0 and
q_1 = q_0 + dh/dp_0.
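For context, here is a minimal sketch of the kind of graph I mean (a toy Hamiltonian network of my own, not the paper's architecture): torch.autograd.grad with create_graph=True returns dh/dp and dh/dq as differentiable tensors, so the step from z_0 to z_1 stays inside the autograd graph.

```python
import torch

k = 4  # latent dimension (arbitrary for this sketch)

# Toy network mapping z = (p, q) to a scalar h.
hamiltonian = torch.nn.Sequential(
    torch.nn.Linear(2 * k, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 1),
)

p0 = torch.randn(1, k, requires_grad=True)
q0 = torch.randn(1, k, requires_grad=True)

h = hamiltonian(torch.cat([p0, q0], dim=-1)).sum()

# create_graph=True keeps the gradient computation itself in the graph,
# so anything computed from dh_dp / dh_dq backpropagates into `hamiltonian`.
dh_dp, dh_dq = torch.autograd.grad(h, (p0, q0), create_graph=True)

p1 = p0 - dh_dq  # one Euler step of Hamilton's equations
q1 = q0 + dh_dp

# p1, q1 would feed the decoder; a stand-in loss shows end-to-end training.
loss = (p1.pow(2) + q1.pow(2)).sum()
loss.backward()
assert hamiltonian[0].weight.grad is not None
```

(In the real model a decoder and reconstruction loss would replace the stand-in loss, but the gradient-through-gradients mechanics are the same.)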
z_1 is then fed to the decoder network and the reconstruction loss is used to train the whole system.
The issue is the following: how can I build a graph in PyTorch that uses the gradients of one part of the network to compute values (in this case
p_1 and q_1) that are then used by the rest of the network? Does PyTorch allow something like this?