I am trying to accomplish something like the following:
z = encoder1(x, y)
theta = encoder2(x, y)
predMean = decoder(x, z, theta)
where x and y are my data.
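For reference, here is a minimal, self-contained version of that setup. The module definitions, sizes, and the concatenation of x and y are just placeholders for my actual networks:

import torch
import torch.nn as nn

# Placeholder networks; my real encoders/decoder are more complex and
# take x and y as separate arguments, here I just concatenate them.
encoder1 = nn.Linear(2, 4)            # produces z from (x, y)
encoder2 = nn.Linear(2, 4)            # produces theta from (x, y)
decoder  = nn.Linear(1 + 4 + 4, 1)    # produces predMean from (x, z, theta)

x = torch.randn(32, 1, requires_grad=True)  # requires_grad so I can differentiate w.r.t. x
y = torch.randn(32, 1)

z = encoder1(torch.cat([x, y], dim=1))
theta = encoder2(torch.cat([x, y], dim=1))
predMean = decoder(torch.cat([x, z, theta], dim=1))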
In the loss function, I would like to introduce a term that is the derivative of predMean with respect to x, treating z and theta as constants.
However, I think that
predMeanGrad = torch.autograd.grad(
    outputs=predMean,
    inputs=x,
    grad_outputs=torch.ones_like(predMean),
    create_graph=True,
    retain_graph=True)[0]
is actually applying the chain rule, so the derivative also contains the terms dz/dx and dtheta/dx.
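In other words, I believe what autograd returns there is the total derivative

d(predMean)/dx = ∂predMean/∂x + (∂predMean/∂z)*(dz/dx) + (∂predMean/∂theta)*(dtheta/dx)

whereas the term I want in the loss is only the partial derivative ∂predMean/∂x, with z and theta held fixed.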
My initial guess was to detach z and theta and compute the derivative of predMean from the detached copies; however, I would then need to re-attach them when back-propagating the loss, so that the encoders still get trained.
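Concretely, the only way I currently see to do that is to run the decoder a second time on detached copies and use that pass only for the gradient term, while the main loss still uses the original predMean. Continuing the sketch above (the MSE loss and the 0.1 penalty weight are just placeholders for my actual loss):

# Second decoder pass with z and theta detached: gradients of this output
# w.r.t. x no longer contain the dz/dx and dtheta/dx contributions.
predMeanConst = decoder(torch.cat([x, z.detach(), theta.detach()], dim=1))

predMeanGrad = torch.autograd.grad(
    outputs=predMeanConst,
    inputs=x,
    grad_outputs=torch.ones_like(predMeanConst),
    create_graph=True,   # keep the graph so the gradient term can itself be back-propagated
    retain_graph=True)[0]

# The main loss uses the original predMean, so encoder1/encoder2 still
# receive gradients through z and theta; only the penalty sees the detached pass.
loss = torch.nn.functional.mse_loss(predMean, y) + 0.1 * predMeanGrad.pow(2).mean()
loss.backward()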
Are there more sophisticated solutions that you know of?