Hi,

According to this post, `torch.autograd.grad` can be used to compute partial derivatives. As per the documentation:

> If `create_graph=True`, graph of the derivative will be constructed, allowing to compute higher order derivative products.

and `retain_graph` will be automatically set to `True`. However, I’m a bit confused about how this works, or in other words, what the computation graph should look like when this option is used.
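To make the documented behavior concrete, here is a minimal sketch on a toy function. The function y = x³ and the variable names are illustrative assumptions, not part of the documentation. With `create_graph=True`, the returned derivative is itself attached to a graph, so it can be differentiated again.

```python
import torch
from torch.autograd import grad

# Toy scalar function y = x^3 (an illustrative assumption).
x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# First derivative dy/dx = 3 x^2. create_graph=True records how it was computed.
dy, = grad(y, x, create_graph=True)

# Because dy has its own graph, it can be differentiated again: d2y/dx2 = 6 x.
d2y, = grad(dy, x)

print(dy.requires_grad)  # dy is part of a graph, so this is True
```

Without `create_graph=True`, `dy` would be returned detached from any graph and the second `grad` call would raise an error.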

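As a tractable example, let x be a known vector, f be a `torch.nn.Module` instance that returns a scalar, and θ be the parameters of f. The loss, as implemented in the code further down, is:

$$L(\theta) = f(x;\theta) + \left\lVert \frac{\partial f(x;\theta)}{\partial x} \right\rVert_1$$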

To implement this loss, which will subsequently be differentiated with respect to θ, I think a good implementation should:

• not compute gradient with respect to θ when evaluating the partial gradient with respect to x
• have the two occurrences of x sharing the same data
• contain only one graph for f(x;θ)
• treat the partial derivative in the second term as a new function with respect to θ
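Some of the properties listed above can be checked programmatically. The following is only a sketch under assumptions: the small `Sequential` network and the names `x_`, `no_param_grads`, `shares_data` and `depends_on_theta` are hypothetical and chosen for illustration.

```python
import torch
from torch.autograd import grad

torch.manual_seed(0)
# Hypothetical small network returning one value per input vector.
f = torch.nn.Sequential(
    torch.nn.Linear(3, 4), torch.nn.Tanh(), torch.nn.Linear(4, 1)
)
x = torch.rand(3)  # x.requires_grad is False

# x_ is a detached alias of x that requires grad; it shares storage with x.
x_ = x.detach().requires_grad_()

# Differentiate only with respect to x_. torch.autograd.grad returns the
# gradients instead of accumulating them into the .grad attributes.
dfdx, = grad(f(x_).sum(), x_, create_graph=True)

no_param_grads = all(p.grad is None for p in f.parameters())
shares_data = x_.data_ptr() == x.data_ptr()
depends_on_theta = dfdx.requires_grad  # dfdx is a function of the parameters

print(no_param_grads, shares_data, depends_on_theta)
```

This only confirms that no parameter gradients are accumulated by the `grad` call, that the two inputs alias the same memory, and that `dfdx` remains differentiable with respect to θ. It says nothing about whether the graph for f is duplicated.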

The computation graph I’d like to achieve should appear as: (legend: node without border -> requires_grad=False; node in oval border -> requires_grad=True; node in rectangle border -> `torch.nn.Module` instance or similar entity in a computation graph; node in diamond border -> the scalar output i.e. the loss)

To attain the desired computation graph I may write the following code (version: `PyTorch-0.4.1`):

```python
import torch
from torch.autograd import grad

# `x` is the input such that x.requires_grad is False;
# `f` is the network.
# As a specific example, assume that:
x = torch.rand(3)
f = torch.nn.Linear(3, 1)

# build graph for f(x;theta) without derivative
loss1 = f(x)

# `x_` shares data with `x` but requires grad, so that df/dx can be evaluated
x_ = x.detach().requires_grad_()
# only_inputs=True: so that gradient of f(x) with respect to theta is not evaluated here
# create_graph=True: so that `dfdx` can be used as intermediate node in other graph
# retain_graph=False: so that we won't have a duplicate graph for f(x)
dfdx, = grad(f(x_), x_, only_inputs=True, create_graph=True, retain_graph=False)
loss2 = torch.norm(dfdx, 1)
loss = loss1 + loss2
loss.backward()
```
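For the specific `torch.nn.Linear(3, 1)` example, the result can be checked against a closed form. Since f(x) = Wx + b, the partial derivative with respect to x is W, so the loss is Wx + b + ‖W‖₁ and its gradient is x + sign(W) with respect to W and 1 with respect to b. The sketch below uses the default `retain_graph` rather than `retain_graph=False`. The names `expected_w_grad`, `w_grad_ok` and `b_grad_ok` are illustrative.

```python
import torch
from torch.autograd import grad

torch.manual_seed(0)
x = torch.rand(3)
f = torch.nn.Linear(3, 1)

loss1 = f(x).sum()
x_ = x.detach().requires_grad_()
dfdx, = grad(f(x_).sum(), x_, create_graph=True)
loss = loss1 + torch.norm(dfdx, 1)
loss.backward()

# Closed form for a linear layer: dL/dW = x + sign(W), dL/db = 1.
expected_w_grad = x + torch.sign(f.weight.detach()).squeeze(0)
w_grad_ok = torch.allclose(f.weight.grad.squeeze(0), expected_w_grad)
b_grad_ok = torch.allclose(f.bias.grad, torch.ones(1))

print(w_grad_ok, b_grad_ok)
```

This check only covers the linear case. For a nonlinear network, a finite-difference comparison would be needed instead.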

This code snippet certainly runs. However, it’s quite difficult for me to directly verify that the result is correct. My questions here are:

• How to visualize the computation graph created by `create_graph=True` in the above simple code?
• Does `retain_graph=False` do what I intend, as described in the comment? What if I change it to `retain_graph=True`?
• Is there a way to not compute `f(x)` twice while maintaining the same computation graph?
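As a starting point for the first and third questions, here is a rough sketch, not a definitive answer. It assumes it is acceptable for `x` itself to require grad, which differs from the setup above and has the side effect that `x.grad` is populated by `backward()`. The helper `walk` is a hypothetical function that lists the autograd nodes reachable from the loss through `grad_fn.next_functions`.

```python
import torch
from torch.autograd import grad

torch.manual_seed(0)
f = torch.nn.Linear(3, 1)
x = torch.rand(3).requires_grad_()  # assumption: x may require grad

out = f(x).sum()                          # the only forward pass
dfdx, = grad(out, x, create_graph=True)   # retain_graph defaults to create_graph
loss = out + torch.norm(dfdx, 1)

def walk(fn, depth=0, lines=None):
    """Collect an indented listing of the autograd nodes reachable from `fn`."""
    lines = [] if lines is None else lines
    if fn is not None:
        lines.append("  " * depth + type(fn).__name__)
        for next_fn, _ in fn.next_functions:
            walk(next_fn, depth + 1, lines)
    return lines

graph_lines = walk(loss.grad_fn)
print("\n".join(graph_lines))

loss.backward()
```

Here the graph of `out` has to be retained, because `loss` reuses `out` and backpropagates through it again. Nodes shared by several paths appear more than once in the listing, since `walk` does not track visited nodes.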