Why is `retain_graph=True` needed in some cases but not in others?

Everywhere I read, people say to just pass retain_graph=True to solve this kind of issue, but I'd like to know what happens under the hood. For example, the loss.backward() call below doesn't need retain_graph at all:

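import torch
import torch.nn as nn
import torch.nn.functional as F

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        return self.lin(xb)

model = Mnist_Logistic()
loss_func = F.cross_entropy
input = torch.randn(64, 784)          # dummy batch of flattened 28x28 images
output = torch.randint(0, 10, (64,))  # dummy class labels
epochs = 3

for epoch in range(epochs):
    pred = model(input)               # the forward pass records a new graph
    loss = loss_func(pred, output)
    loss.backward()                   # frees that graph; no retain_graph needed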

For my model below, PyTorch complains that the computational graph has already been freed, so when I run backward() in the next epoch, it cannot access the state variables. The code above, however, seems to have no such problem.

class Ode(nn.Module):
    def __init__(self, len_data, alpha=0.57, beta=0.11, ...):
        super().__init__()
        self.I, self.E, self.H = torch.zeros(1), torch.zeros(1), torch.zeros(1)
        self.S = nn.Parameter(torch.tensor([0.5]).to(device))
        # move the tensor *before* wrapping it: calling .to() on a Parameter can return
        # a plain non-leaf tensor that is not registered as a parameter of the module
        self.sigma = nn.Parameter(torch.tensor([sigma]).to(device))
        # many nn.Parameter later

    def f(self, t, y):
        # only + and * parameters together
        ...

    def forward(self, I0, E0, H0):
        return torchdiffeq.odeint(self.f, t=time_range, y0=self.y0, method='rk4')


for data, y_exact in tqdm(train_dataloader):
    optimizer.zero_grad()    
    y_approx = model(data[0,0,0], data[0,0,1], data[0,0,2])
    loss = loss_fun(y_approx[1:, [1, 6, 7]], y_exact.squeeze())
    loss.backward()
    optimizer.step()
    scheduler.step()

Hi,
PyTorch frees memory aggressively: as soon as .backward() is called on a tensor, PyTorch frees the references to the saved tensors in the computation graph that produced that tensor.

These saved tensors are required to compute the gradients through that graph, so a second backward call through the same graph raises an error: PyTorch no longer has access to the saved tensors it needs for the gradient calculation.

Passing retain_graph=True tells PyTorch not to free these references, so the same graph can be backpropagated through again.
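A minimal sketch of this behavior, using a single scalar instead of a model. `exp()` is chosen only because its backward pass needs a saved tensor (its own output); the variable names are illustrative. The first call keeps the graph alive, the second call frees it, and a third call fails.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x.exp()  # exp() saves its output, which its backward pass needs

y.backward(retain_graph=True)  # first pass: keep the saved tensors alive
y.backward()                   # second pass works, but this call frees the graph
print(x.grad)                  # gradients accumulate: 2 * exp(2) ≈ 14.78

try:
    y.backward()               # third pass: the saved tensors are gone
    third_pass_failed = False
except RuntimeError as err:
    third_pass_failed = True
    print("RuntimeError:", err)
```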

So, in the first code that you posted, each time the for loop for training is run, a new computation graph is created - PyTorch uses dynamic graphs. This new graph saves its own references to the tensors it will need for the gradient computation, so freeing the previous graph does no harm, and there's no need for retain_graph=True here.
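The difference can be sketched with toy tensors (`w` and `h` are made-up names). Pattern A recomputes everything inside the loop, like the first snippet. Pattern B is a hypothetical bug in which a tensor that depends on a parameter is computed once and reused across iterations; this is one common way to get the "backward through the graph a second time" error, but the thread does not confirm that it is what happened in the second model.

```python
import torch

w = torch.tensor(1.5, requires_grad=True)

# Pattern A: the forward pass runs inside the loop, so every iteration
# records a brand-new graph and the freed one is never needed again.
for _ in range(3):
    loss = (w.exp() - 2.0) ** 2
    loss.backward()

# Pattern B (hypothetical bug): part of the graph is built once, outside the loop.
h = w.exp()  # this node and its saved output are shared by every iteration
failed_iterations = []
for i in range(3):
    loss = (h - 2.0) ** 2   # only this part is rebuilt each time
    try:
        loss.backward()     # iteration 0 frees h's saved tensors
    except RuntimeError:
        failed_iterations.append(i)

print(failed_iterations)  # [1, 2]
```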

As for the second snippet, please post the part where the .backward() calls are made.


Hi Srishti, thank you for your answer. I've edited the question.

Thanks for editing.
But could you please post a minimal executable snippet that reproduces the error you are facing?

From just skimming through the code of the second model, I cannot see why this error would occur.

Anyway, were you able to understand the concept I explained in my previous reply about what happens under the hood?


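The common advice to "just pass retain_graph=True" is, unfortunately, generally wrong.
Do not pass retain_graph=True to any backward call unless you explicitly need it and can explain why your use case requires it.
Usually, it's used as a workaround that causes other issues later, such as higher memory usage or errors about variables modified by an inplace operation once optimizer.step() has updated the parameters. The mechanics of this argument were explained well by @srishti-git1110.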
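A rough sketch of why the workaround backfires, assuming a hypothetical bug where `h` is computed once outside the training loop. Passing `retain_graph=True` avoids the first error, but once `optimizer.step()` changes the parameter in place, the retained graph refers to an outdated value and the next backward call fails anyway. `w`, `h` and the SGD settings are arbitrary choices for this illustration.

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.5]))
opt = torch.optim.SGD([w], lr=0.1)

h = w * w  # built once outside the loop (hypothetical bug); mul saves w for backward
failed_iterations = []
for i in range(2):
    opt.zero_grad()
    loss = ((h - 2.0) ** 2).sum()
    try:
        loss.backward(retain_graph=True)  # the "just pass retain_graph=True" workaround
    except RuntimeError:
        # "one of the variables needed for gradient computation
        #  has been modified by an inplace operation"
        failed_iterations.append(i)
    opt.step()  # updates w in place, so the retained graph now holds a stale w

print(failed_iterations)  # [1]
```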


I managed to create an MRE, shown below. The backward() error is gone, but now all the param.grad values are 0, and I don't know how to debug that. Do you have any idea?

import torch
import torch.nn as nn
# torchdiffeq acts like scipy.integrate's ODE solvers but works with Torch tensors (pip install torchdiffeq)
from torchdiffeq import odeint

class Ode(nn.Module):
    def __init__(self, beta=0.11, gamma=0.456):
        super().__init__()
        torch.set_default_dtype(torch.float64)
        self.I,  self.R = torch.zeros(1), torch.zeros(1)
        self.S = nn.Parameter(torch.tensor([0.5]))  # S is a hidden param
        self.y0 = torch.tensor([self.S, self.I, self.R])
        self.beta = nn.Parameter(torch.tensor([beta]))
        self.gamma = nn.Parameter(torch.tensor([gamma]))
        self.len_data = 2

    def f(self, t, y):
        S, I, R = y
        N = S + I + R
        return torch.tensor([S * I * self.beta / N,  # S
                             self.beta*I*S / N - self.gamma*I,   # I
                             self.gamma*I   # R
                             ], dtype=torch.float64, requires_grad=True)

    def forward(self, I0, R0):
        self.y0[1] = torch.Tensor([I0])
        self.y0[2] = torch.Tensor([R0])
        time_range = torch.linspace(0, self.len_data, self.len_data + 1)
        return odeint(self.f, t=time_range, y0=self.y0, method='rk4').double()


input = torch.DoubleTensor([[2,3], [4,5], [5,6]])
expected = torch.DoubleTensor([[[4, 5], [5, 6]], [[5, 6], [7, 8]], [[7, 8], [9, 10]]])
loss_fun = torch.nn.MSELoss()
model = Ode()
optimizer = torch.optim.Adam(model.parameters(), 1e-2)

for i in range(len(input)):
    optimizer.zero_grad()
    approx = model(input[i][0], input[i][1])
    observable_vars = approx[1:, [1, 2]]
    loss = loss_fun(observable_vars, expected[i])
    loss.backward()
    for param in optimizer.param_groups[0]['params']:
        if param.requires_grad:
            print(param.grad)
    optimizer.step()

Re-wrapping tensors with torch.tensor() detaches them from the computation graph and creates a new leaf tensor:

        return torch.tensor([S * I * self.beta / N,  # S
                             self.beta*I*S / N - self.gamma*I,   # I
                             self.gamma*I   # R
                             ], dtype=torch.float64, requires_grad=True)

Do you see any issues if you return the result directly instead of re-creating a new tensor?
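A small sketch of the difference, assuming scalar (0-dim) inputs: `torch.tensor()` only copies the values, while `torch.stack` (or `torch.cat` for 1-D pieces, as used later in this thread) records an operation in the graph. The names `a`, `b`, `rewrapped` and `joined` are illustrative.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

# Re-wrapping: torch.tensor() copies the values into a new leaf tensor,
# so the result has no connection to a or b.
rewrapped = torch.tensor([a * b, a + b], requires_grad=True)
rewrapped.sum().backward()
grad_after_rewrap = a.grad
print(grad_after_rewrap)  # None

# Keeping the graph: torch.stack records the op, so gradients reach a.
joined = torch.stack([a * b, a + b])
joined.sum().backward()
print(a.grad)  # tensor(4.) = d(ab)/da + d(a+b)/da = b + 1
```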


The only issue was that I didn't know creating a new tensor detaches it from the computational graph 🙈. Thank you!

I changed that snippet to this

return torch.cat([S * I * self.beta / N,                   # S
                  self.beta * I * S / N - self.gamma * I,  # I
                  self.gamma * I,                          # R
                  ])

Now

    optimizer.zero_grad()
    ...
    for name, param in model.named_parameters():
        print(name, param.requires_grad, param.grad)
    print(loss)
    ...
    optimizer.step()

will print

S True None
beta True tensor([-0.5961])
gamma True tensor([2.9083])
tensor(7.4195, grad_fn=<MseLossBackward0>)
S True None
beta True tensor([-0.6962])
gamma True tensor([11.5022])
tensor(8.7388, grad_fn=<MseLossBackward0>)
S True None
beta True tensor([-1.0248])
gamma True tensor([17.9899])
tensor(15.6819, grad_fn=<MseLossBackward0>)

I still don't understand why S.grad is None.

Thank you for the explanation. Yes, I was able to understand the mechanism of retain_graph=True. What I didn't know is what you said here:

So, in the first code that you posted, each time the for loop for training is run, a new computation graph is created - PyTorch uses dynamic graphs.

Does that mean that, because the first snippet uses PyTorch layers, the graph-recreation logic is built into them?

No, it's not something built into the layers. It only means that every iteration of the training loop reruns the forward pass, which records a fresh graph for the new loss tensor (the previous graph's buffers were already freed by the backward call on the previous loss). This is what dynamic graph creation means: the graph is built from scratch, on the fly, every time the forward code runs.

And since a new graph is created in every iteration, there's no need to retain the previous one.
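To illustrate that nothing layer-specific is involved, this sketch runs the same loop through an `nn.Linear` layer and through a plain tensor `w` with `@`; the shapes are arbitrary. Each iteration produces a distinct root node (`loss.grad_fn`), i.e. a new graph, and neither loop needs `retain_graph=True`.

```python
import torch
import torch.nn as nn

x = torch.randn(8, 784)
layer = nn.Linear(784, 10)                    # built-in layer
w = torch.randn(784, 10, requires_grad=True)  # plain tensor, no nn.Module involved

layer_graphs, raw_graphs = [], []
for _ in range(3):
    loss_layer = layer(x).pow(2).mean()  # forward pass through nn.Linear
    loss_raw = (x @ w).pow(2).mean()     # same kind of computation with raw ops
    # keep the root nodes alive so they can be compared after the loop
    layer_graphs.append(loss_layer.grad_fn)
    raw_graphs.append(loss_raw.grad_fn)
    loss_layer.backward()  # no retain_graph needed in either case
    loss_raw.backward()

print(len(set(map(id, layer_graphs))), len(set(map(id, raw_graphs))))  # 3 3
```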


I've never used odeint, but from inspecting your code, I think the issue might be that the loss tensor is not a function of self.S, i.e. self.S might not be used anywhere in the computation of loss.

That would explain why self.S.grad is None after the backward call on loss.

The reason I think so is that self.beta and self.gamma are used in the calculations inside the method f.

However, I cannot see self.S being used anywhere in a way that would make the final loss tensor depend on it.

Not very sure though. I might be wrong.
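One way to check this hypothesis is `torch.autograd.grad(..., allow_unused=True)`, which returns `None` for inputs the output does not depend on. In the MRE, `self.S` only enters the model through `self.y0 = torch.tensor([self.S, self.I, self.R])`, the same re-wrapping pattern discussed earlier in the thread. This sketch replaces odeint and the MSE loss with a hypothetical `toy_loss` and assumes `I0` and `R0` are 1-element tensors, so it only shows the detached-vs-connected distinction, not the actual ODE model or the final fix.

```python
import torch
import torch.nn as nn

S = nn.Parameter(torch.tensor([0.5]))      # hidden initial state, as in the MRE
beta = nn.Parameter(torch.tensor([0.11]))  # a parameter that *is* used later
I0, R0 = torch.tensor([2.0]), torch.tensor([3.0])

def toy_loss(y0):
    # hypothetical stand-in for odeint + MSELoss: any differentiable function of y0 and beta
    return ((y0 * beta) ** 2).sum()

candidates = {
    # torch.tensor() copies plain numbers, so y0 is cut off from S
    "detached": torch.tensor([S.item(), I0.item(), R0.item()]),
    # torch.cat records the op, so gradients can flow back to S
    "connected": torch.cat([S, I0, R0]),
}

grads = {}
for name, y0 in candidates.items():
    # allow_unused=True returns None instead of raising when the loss ignores an input
    grads[name] = torch.autograd.grad(toy_loss(y0), [S, beta], allow_unused=True)
    print(name, grads[name])  # the "detached" entry has None as the gradient for S
```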


Thank you both! I used your hints and fixed the problem; gradients now flow through all parameters. Since I can only mark one post as the solution, I don't know which one to pick, so please let me know 😃. For now, I'm just giving all of them a heart.