PyTorch Autograd for Regression

Another PyTorch newbie here, trying to understand PyTorch’s computational graph and autograd.

I’m training the following model on a potential energy and the corresponding force.

import torch
from torch import nn
from torch.autograd import grad

model = nn.Sequential(
    nn.Linear(1, 32),
    nn.Linear(32, 32), nn.Tanh(),
    nn.Linear(32, 32), nn.Tanh(),
    nn.Linear(32, 1)
)

optimizer = torch.optim.Adam(model.parameters())
loss = nn.MSELoss()
# generate data
r = torch.linspace(0.95, 3, 50, requires_grad=True).view(-1, 1)
E = 1 / r
F = -grad(E.sum(), r)[0]

inputs = r

for epoch in range(10**3):
    E_pred = model(inputs)
    F_pred = -grad(E_pred.sum(), r, create_graph=True, retain_graph=True)[0]

    optimizer.zero_grad()
    error = loss(E_pred, E.detach()) + loss(F_pred, F.detach())
    error.backward()
    optimizer.step()

However, if I change inputs = r to inputs = 1*r, the training loop fails with the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Could you please explain why this happens?

I’m surprised that the .view doesn’t cause trouble for you, too.

When you do an assignment inputs = r, the name on the left-hand side is bound to exactly the object on the right-hand side, in this case the very same Tensor object that r refers to.
By contrast, 1 * r is a computed result (even if the computation is not very interesting), and that new tensor is what gets assigned to inputs.
Similarly, the .view counts as “computation” to autograd, so r itself is already the result of a computation rather than a leaf.
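A small sketch that makes the distinction visible, using the same linspace setup as in the question (the names base, inputs_same and inputs_mul are just for illustration). It checks object identity and whether autograd recorded a grad_fn for each tensor:

```python
import torch

# Leaf: created directly by the user with requires_grad=True.
base = torch.linspace(0.95, 3, 50, requires_grad=True)

r = base.view(-1, 1)   # the view is recorded by autograd, so r is not a leaf
inputs_same = r        # plain assignment: just another name for the same object
inputs_mul = 1 * r     # a new tensor, produced by a recorded computation

print(base.is_leaf, r.is_leaf, inputs_mul.is_leaf)  # True False False
print(inputs_same is r, inputs_mul is r)            # True False
# The exact grad_fn class names depend on the PyTorch version.
print(type(r.grad_fn).__name__, type(inputs_mul.grad_fn).__name__)
```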

Now, if inputs is an “interior” node, i.e. the result of a computation, then every backward pass also goes back through the part of the computational graph that produced inputs. That part is built only once, before the loop, so the first iteration’s backward frees its saved intermediate values, and the second iteration fails when backward needs them again.
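A minimal reproduction of this, independent of the model in the question. Here w is a hypothetical stand-in for a model parameter and run_two_iterations is a hypothetical helper that mimics two loop iterations. The part of the graph built outside the loop is either a view or a multiplication. The view’s backward only needs the shape, not saved tensors, which would be consistent with the view not failing in the question; x * x saves its inputs, which the first backward() frees:

```python
import torch

w = torch.ones(3, requires_grad=True)  # stands in for a model parameter


def run_two_iterations(h):
    """Backpropagate through h twice, as a training loop would."""
    for _ in range(2):
        out = (w * h).sum()  # rebuilt every iteration
        out.backward()       # also walks back through the graph that produced h


x = torch.randn(3, requires_grad=True)

# h built by a view outside the loop: its backward node holds no saved
# tensors, so traversing it a second time does not raise.
run_two_iterations(x.view(3))

# h built by a multiplication outside the loop: its backward node saves x,
# which the first backward() frees, so the second backward() raises.
try:
    run_two_iterations(x * x)
    failed = False
except RuntimeError as e:
    failed = "second time" in str(e)

print(failed)  # True
```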

The fix is to keep only leaves (i.e. non-computed values, like the ones you pass to the optimizer) or tensors that don’t require grad outside the loop, and to do all computation from them within the for loop.
In some cases (e.g. the view), it may be better not to pass requires_grad to the factory function (linspace) and instead call .requires_grad_() after the view. This way, r will be a leaf again.
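A sketch of the loop from the question with both suggestions applied: r is made a leaf by calling .requires_grad_() after the view, the targets are computed once and carry no graph, and inputs = 1 * r is computed inside the loop so every iteration builds its own graph. The two-layer model is a smaller stand-in chosen for brevity, not the architecture from the question, and the five epochs are arbitrary:

```python
import torch
from torch import nn
from torch.autograd import grad

torch.manual_seed(0)

# Create r without requires_grad, reshape it, then enable gradient tracking,
# so r is a leaf rather than the output of a recorded view.
r = torch.linspace(0.95, 3, 50).view(-1, 1).requires_grad_()

# Targets, computed once before the loop.
E_full = 1 / r
F = -grad(E_full.sum(), r)[0]  # no create_graph, so F does not require grad
E = E_full.detach()

# Smaller stand-in for the model in the question.
model = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters())
loss = nn.MSELoss()

for epoch in range(5):
    # All computation starting from the leaf r happens inside the loop,
    # so each iteration builds (and backward then frees) a fresh graph.
    inputs = 1 * r
    E_pred = model(inputs)
    F_pred = -grad(E_pred.sum(), r, create_graph=True)[0]

    optimizer.zero_grad()
    error = loss(E_pred, E) + loss(F_pred, F)
    error.backward()
    optimizer.step()

print(epoch, error.item())
```

Since r requires grad, error.backward() also accumulates into r.grad on every iteration; that is harmless here because r is never updated.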

Best regards

Thomas