Here’s the example code:
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, d_model):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, 1)

    def forward(self, x):
        return self.linear1(x)

    def calculate(self, x):
        return self.linear2(x)


batch_size, d_model, target = 4, 3, 1
data = torch.rand(batch_size, d_model)
model = Model(d_model)
temp = model(data)  # (batch_size, d_model)
optimizer = torch.optim.Adam(model.parameters())
for i in range(10):
    y = model.calculate(temp)  # (batch_size, 1)
    loss = ((y - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Assume that the forward of the model is so complex that I cannot afford to recompute it in every iteration of the for loop. I get the error:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
I think backward destroys the graph built by forward, so the graph can no longer be used in the second and subsequent updates. loss.backward(retain_graph=True) solves the problem, but is there another, more efficient way? Any help would be much appreciated.
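For reference, this is roughly how the retain_graph workaround looks in the loop above (just a sketch; it keeps the saved activations of forward in memory across all iterations):

# Sketch: retain the graph that produced `temp` so backward() can
# traverse it again in the next iteration, at the cost of extra memory.
for i in range(10):
    y = model.calculate(temp)  # (batch_size, 1)
    loss = ((y - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward(retain_graph=True)  # do not free the saved tensors of forward
    optimizer.step()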
The intermediate tensors created in forward are freed in the backward call, and the next iteration will fail since you are only recomputing the calculate part of your model.
Could you explain your use case a bit more and why forward is not recomputed anymore?
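For comparison, the usual pattern that avoids the error is to rebuild the graph in every iteration by moving the forward call into the loop (a sketch reusing the names from the snippet above):

# Standard pattern: recompute forward inside the loop so that each
# backward() frees only the graph built in that iteration.
for i in range(10):
    temp = model(data)  # fresh graph every iteration
    y = model.calculate(temp)
    loss = ((y - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()  # freeing this iteration's graph is harmless
    optimizer.step()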
Thanks for your reply!
I’m applying Proximal Policy Optimization (PPO) to my problem, following the code in vwxyzjn/ppo-implementation-details. The algorithm is shown below:
[PPO algorithm pseudocode figure]
where the advantage estimate A is
- used to obtain the surrogate loss function L;
- calculated by the current value function V, parameterized by φ;
- used repeatedly over K epochs.
The input of V is usually the observation of an agent (flattened data without gradients), but in my case the input of V is the output of a parameterized deep network that aggregates features. I have renamed a few things in the code to make this clearer.
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, d_model):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, 1)

    def feature_aggregation(self, x):
        return self.linear1(x)

    def value_function(self, x):
        return self.linear2(x)


batch_size, d_model, target = 4, 3, 1
observation = torch.rand(batch_size, d_model)
model = Model(d_model)
h_observation = model.feature_aggregation(observation)  # (batch_size, d_model)
optimizer = torch.optim.Adam(model.parameters())
for i in range(10):
    value = model.value_function(h_observation)  # (batch_size, 1)
    loss = ((value - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
The feature_aggregation step (i.e. forward in the original code) is calculated only once because I want to save running time. Please let me explain my thoughts in detail, but forgive my limited understanding of the backward mechanism (a sketch of what I have in mind follows this list):
- since the input of feature_aggregation is the same for all epochs, I do not want to spend time recomputing it in every epoch, but instead want to keep the computational graph that leads to h_observation;
- whenever value is calculated, I want to build the computational graph for value “on the basis of” the previously saved computational graph for h_observation;
- during backward, I wonder if I can free only the computational graph of the value_function part, because it is no longer needed and freeing it may save time and memory.
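In code, what I have in mind is roughly the following (only a sketch of my intent; I am not sure this is the right way to use retain_graph):

# Sketch of my intent: build the feature_aggregation graph once, rebuild
# only the small value_function subgraph in each epoch, and retain the
# shared part of the graph across backward() calls.
h_observation = model.feature_aggregation(observation)  # graph built once
for epoch in range(10):
    value = model.value_function(h_observation)  # new subgraph each epoch
    loss = ((value - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward(retain_graph=True)  # keep the feature_aggregation graph alive
    optimizer.step()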
Based on your explanation it seems h_observation is treated as a static input and you don’t want to optimize the feature_aggregation step in each iteration. If so, you could execute this step in a no_grad guard to avoid creating the computation graph:

with torch.no_grad():
    h_observation = model.feature_aggregation(observation)

which will create h_observation as a tensor without any Autograd history.
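To make this concrete, a minimal sketch of the full loop with the guard applied (using the names from the snippet above; calling .detach() on an already-computed h_observation would similarly stop gradients, though the no_grad guard also avoids building the graph in the first place):

# Sketch: compute the features once without Autograd history, then
# optimize only the value_function part; each backward() builds and
# frees just the small linear2 subgraph.
with torch.no_grad():
    h_observation = model.feature_aggregation(observation)

for epoch in range(10):
    value = model.value_function(h_observation)
    loss = ((value - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()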