Using gradients in an intermediate step of an encoder-decoder

I am trying to reproduce Hamiltonian Generative Networks.

They encode a sequence of images into a latent state z_0 = (p_0, q_0), where p_0 and q_0 are k-dimensional vectors corresponding to momentum and position at time t=0. A second network then maps z_0 to a scalar value h (the Hamiltonian). From that they compute z_1 = (p_1, q_1), with p_1 = p_0 - dh/dq_0 and q_1 = q_0 + dh/dp_0 (one discrete step of Hamilton's equations). Finally, z_1 is fed to the decoder network, and the reconstruction loss is used to train the whole system end to end.

The issue is the following: how can I build a graph in PyTorch that uses the gradients of one part of the network to compute values (here p_1 and q_1) that are then consumed by the rest of the network? Does PyTorch allow something like this?
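For concreteness, the closest I have come up with looks roughly like this (encoder, hamiltonian and decoder are stand-ins for my three networks):

def forward(images, encoder, hamiltonian, decoder):
    p0, q0 = encoder(images)          # k-dimensional momentum and position at t=0
    h = hamiltonian(p0, q0).sum()     # scalar Hamiltonian (summed over the batch)
    # gradients of h w.r.t. the latent state, computed during the forward pass
    dh_dq0, dh_dp0 = torch.autograd.grad(h, (q0, p0))
    p1 = p0 - dh_dq0
    q1 = q0 + dh_dp0
    return decoder(q1, p1)

but as far as I can tell the gradients returned by autograd.grad here are detached from the graph, so the reconstruction loss would never train the Hamiltonian network through p_1 and q_1.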

Yes, you can calculate (sub-graph) gradients during the forward pass. If you do it inside a custom torch.autograd.Function, you can plug the precalculated gradients into the outer graph in backward(). Not sure if this maps exactly onto your case, but here is an illustrative snippet:

import torch

# _pdf(x, p0, g) is the differentiable (log-)density being wrapped; it is assumed
# to be defined elsewhere.

class LpdfGenAsymLaplace(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_in, p0, g_in):
        if any(ctx.needs_input_grad):
            # forward() runs with grad disabled, so re-enable it to build a local
            # graph and differentiate the density w.r.t. x and g right away
            with torch.enable_grad():
                x = x_in.detach().requires_grad_()
                g = g_in.detach().requires_grad_()
                lpdf = _pdf(x, p0, g)
                # create_graph=True keeps the gradients themselves differentiable,
                # so they can be reused inside the outer graph
                dLpdf_dX, dLpdf_dG = torch.autograd.grad(
                    lpdf, (x, g), torch.ones_like(lpdf),
                    retain_graph=True, create_graph=True)
                ctx.save_for_backward(dLpdf_dX, dLpdf_dG)
                return lpdf
        else:
            return _pdf(x_in, p0, g_in)

    @staticmethod
    def backward(ctx, gr_in):
        # chain rule using the gradients precomputed in forward();
        # p0 is treated as a constant, hence the None
        dLpdf_dX, dLpdf_dG = ctx.saved_tensors
        return dLpdf_dX * gr_in, None, dLpdf_dG * gr_in
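
In your case you might not even need a custom Function. If p0 and q0 come out of your encoder and h = hamiltonian(p0, q0).sum() is the scalar Hamiltonian (placeholder names; summing over the batch just gives autograd.grad a scalar), then create_graph=True is what keeps the derivatives inside the graph:

# create_graph=True makes dh/dq0 and dh/dp0 differentiable tensors that stay in
# the autograd graph, so the decoder's reconstruction loss can train both the
# Hamiltonian network and the encoder through them
dh_dq0, dh_dp0 = torch.autograd.grad(h, (q0, p0), create_graph=True)
p1 = p0 - dh_dq0    # one discrete step of Hamilton's equations
q1 = q0 + dh_dp0
reconstruction = decoder(q1, p1)

The Function approach above is mainly useful when you want explicit control over what backward sees (here, for instance, p0 is deliberately treated as a constant); for a forward pass like yours, plain autograd.grad with create_graph=True should be enough.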