I am trying to accomplish something like the following:

```
z = encoder1(x, y)
theta = encoder2(x, y)
predMean = decoder(x, z, theta)
```

where `x` and `y` are my data.

In the loss function, I would like to introduce a term that is the derivative of `predMean` with respect to `x`, treating `z` and `theta` as constants.

However, I think that

```
predMeanGrad = torch.autograd.grad(
    outputs=predMean,
    inputs=x,
    create_graph=True,
    retain_graph=True,
    grad_outputs=torch.ones_like(predMean))[0]
```

is actually computing the total derivative: the chain rule is applied through the encoders, so the result also contains the dz/dx and dtheta/dx terms.
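To make the concern concrete, here is a minimal scalar sketch (with hypothetical stand-ins for the encoder and decoder) showing that `torch.autograd.grad` follows the path through `z`:

```python
import torch

# Hypothetical toy setup: z plays the role of encoder1(x, y),
# and predMean the role of decoder(x, z, ...).
x = torch.tensor(2.0, requires_grad=True)
z = x ** 2          # z depends on x
predMean = z * x    # predMean = x**3 overall

# autograd returns the TOTAL derivative d(x**3)/dx = 3*x**2 = 12,
# not the partial derivative with z held constant, which would be z = 4.
(total,) = torch.autograd.grad(predMean, x, create_graph=True)
print(total.item())  # 12.0, includes the dz/dx contribution
```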

My initial guess was to detach `z` and `theta` and compute the derivative of `predMean` on that detached graph; however, I would then need to "reattach" them so that gradients still flow into the encoders when back-propagating the loss.
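For reference, this is roughly what I mean, as a minimal sketch with a hypothetical stand-in decoder: the detached copies are used only for the derivative term, while the rest of the loss uses the attached `z` and `theta`:

```python
import torch

def decoder(x, z, theta):
    # Hypothetical stand-in for the real decoder network.
    return theta * z * x

x = torch.tensor(3.0, requires_grad=True)
z = x ** 2       # stand-in for encoder1(x, y)
theta = 2 * x    # stand-in for encoder2(x, y)

# Derivative term: z and theta detached, so only the direct
# dependence of predMean on x is differentiated.
predMeanDetached = decoder(x, z.detach(), theta.detach())
(predMeanGrad,) = torch.autograd.grad(
    predMeanDetached, x, create_graph=True, retain_graph=True)
print(predMeanGrad.item())  # theta * z = 6 * 9 = 54.0

# The remaining loss terms would use the attached z and theta,
# so gradients still reach the encoders through this path:
predMean = decoder(x, z, theta)
```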

Do you know of a cleaner or more idiomatic solution?