I am trying to accomplish something like the following:
z = encoder1(x, y)
theta = encoder2(x, y)
predMean = decoder(x, z, theta)
where x and y are my data.
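For concreteness, here is a minimal runnable sketch of the setup; the module definitions, the dimensions, and the names `Encoder`/`Decoder` are placeholders for my actual models:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    # placeholder encoder: maps (x, y) to a latent code
    def __init__(self, in_dim, latent_dim):
        super().__init__()
        self.net = nn.Linear(2 * in_dim, latent_dim)

    def forward(self, x, y):
        return self.net(torch.cat([x, y], dim=-1))

class Decoder(nn.Module):
    # placeholder decoder: maps (x, z, theta) to the predicted mean
    def __init__(self, in_dim, latent_dim):
        super().__init__()
        self.net = nn.Linear(in_dim + 2 * latent_dim, in_dim)

    def forward(self, x, z, theta):
        return self.net(torch.cat([x, z, theta], dim=-1))

encoder1 = Encoder(in_dim=4, latent_dim=8)
encoder2 = Encoder(in_dim=4, latent_dim=8)
decoder = Decoder(in_dim=4, latent_dim=8)

x = torch.randn(16, 4, requires_grad=True)  # x must require grad for the derivative term
y = torch.randn(16, 4)

z = encoder1(x, y)
theta = encoder2(x, y)
predMean = decoder(x, z, theta)
```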
In the loss function, I would like to introduce a term that is the derivative of predMean with respect to x, treating z and theta as constants.
However, I think that
predMeanGrad = torch.autograd.grad(
    outputs=predMean,
    inputs=x,
    grad_outputs=torch.ones(predMean.shape),
    create_graph=True,
    retain_graph=True,
)[0]
is actually applying the chain rule, so the resulting derivative also contains the dz/dx and dtheta/dx terms.
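To spell out the concern, I believe the call above computes the total derivative

$$
\frac{d\,\mathrm{predMean}}{dx}
= \frac{\partial\,\mathrm{decoder}}{\partial x}
+ \frac{\partial\,\mathrm{decoder}}{\partial z}\,\frac{\partial z}{\partial x}
+ \frac{\partial\,\mathrm{decoder}}{\partial \theta}\,\frac{\partial \theta}{\partial x},
$$

whereas the term I want in the loss is only the first summand, i.e. the partial derivative with z and theta held fixed.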
My initial guess was to detach z and theta and compute the derivative of predMean from the detached copies; however, I would then need to somehow reattach them when back-propagating the loss, as sketched below.
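Roughly what I had in mind, as a sketch; `lam`, the MSE reconstruction term, and the squared-mean form of the penalty are only stand-ins for my actual loss:

```python
import torch.nn.functional as F

lam = 0.1  # hypothetical penalty weight

# Detached copies: z and theta are treated as constants w.r.t. x
z_det = z.detach()
theta_det = theta.detach()
predMean_const = decoder(x, z_det, theta_det)

# d(predMean_const)/dx contains no dz/dx or dtheta/dx terms
predMeanGrad = torch.autograd.grad(
    outputs=predMean_const,
    inputs=x,
    grad_outputs=torch.ones_like(predMean_const),
    create_graph=True,   # keep the graph so the penalty itself can be back-propagated
    retain_graph=True,
)[0]

loss = F.mse_loss(predMean, y) + lam * predMeanGrad.pow(2).mean()
loss.backward()
# Problem: the penalty term never back-propagates through encoder1/encoder2,
# which is what I mean by having to "reattach" z and theta.
```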
Are there more sophisticated solutions that you are aware of?