Reinforcement Learning: RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)

I am trying to write my own actor-critic algorithm. Unlike other implementations, I tried to keep separate actor and critic networks.

The problem arises somewhere in my actor or critic loss function.

I found a similar question here, "Backward error, although there are two different networks for actor and critic (PPO implementation)", but it doesn’t solve my problem.

The error originates here:

advantage = nrml_disc_rewards-values

critic_loss = advantage.pow(2).mean()
actor_loss = -(torch.sum(torch.log(prob_batch)*advantage))
policy_opt.zero_grad()
actor_loss.backward()
policy_opt.step()
value_opt.zero_grad()
critic_loss.backward()
value_opt.step()

This is the full traceback:

D:\q_learning\actor_critic.py:90: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\utils\tensor_new.cpp:233.)
  state_batch = torch.Tensor([s for (s,a,r, ns) in transitions]).to(device)
Traceback (most recent call last):
  File "D:\q_learning\actor_critic.py", line 112, in <module>
    critic_loss.backward()
  File "C:\Users\anaconda3\envs\torch_2\lib\site-packages\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "C:\Users\anaconda3\envs\torch_2\lib\site-packages\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Process finished with exit code 1
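The UserWarning at the top of this output is separate from the crash. Following the warning's own suggestion, here is a small sketch that stacks the states into a single ndarray before converting. The `transitions` list is a hypothetical stand-in filled with random CartPole-sized states, and `dtype=np.float32` is an assumption chosen to match the default dtype that `torch.Tensor(...)` produces.

```python
import numpy as np
import torch

# Hypothetical stand-in for the (s, a, r, ns) tuples collected during an episode:
# five 4-dimensional states stored as separate ndarrays.
transitions = [
    (np.random.rand(4), 0, float(t + 1), np.random.rand(4))
    for t in range(5)
]

# Slow path that triggers the warning: a Python list of ndarrays.
# state_batch = torch.Tensor([s for (s, a, r, ns) in transitions])

# Suggested path: build one float32 ndarray first, then wrap it without copying.
states_np = np.array([s for (s, a, r, ns) in transitions], dtype=np.float32)
state_batch = torch.from_numpy(states_np)
print(state_batch.shape, state_batch.dtype)
```

The same pattern applies to the `next_state_batch` line; it only affects speed, not the backward error.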

Here is my code:

# Modified this code - https://github.com/DeepReinforcementLearning/DeepReinforcementLearningInAction/blob/master/Chapter%204/Ch4_book.ipynb
# Also, modified this code - https://github.com/higgsfield/RL-Adventure-2/blob/master/1.actor-critic.ipynb
import numpy as np
import gym
import torch
from torch import nn
import matplotlib.pyplot as plt

env = gym.make('CartPole-v0')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
learning_rate = 0.0001
episodes = 10000

def discount_rewards(reward, gamma=0.99):
    # Scales element t of the (flipped) reward tensor by gamma**t
    return torch.pow(gamma, torch.arange(len(reward))) * reward

def normalize_rewards(disc_reward):
    # Divides by the maximum so the largest value becomes 1
    return disc_reward / (disc_reward.max())
class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(state_size, 300), nn.ReLU(),
            nn.Linear(300, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, action_size),
            nn.Softmax(dim=-1),  # explicit dim: same result for single states and batches, no warning
        )

    def forward(self, x):
        return self.linear_relu_stack(x)

class Critic(nn.Module):
    def __init__(self, state_size, action_size):
        super(Critic, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.linear_stack = nn.Sequential(
            nn.Linear(state_size, 300), nn.ReLU(),
            nn.Linear(300, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 1),  # single state-value estimate
        )

    def forward(self, x):
        return self.linear_stack(x)
actor = Actor(env.observation_space.shape[0], env.action_space.n).to(device)
critic = Critic(env.observation_space.shape[0], env.action_space.n).to(device)
policy_opt = torch.optim.Adam(params=actor.parameters(), lr=learning_rate)
value_opt = torch.optim.Adam(params=critic.parameters(), lr=learning_rate)
score = []
for i in range(episodes):
    print("i = ", i)
    state = env.reset()
    done = False
    transitions = []
    tot_rewards = 0
    while not done:
        value = critic(torch.from_numpy(state).to(device))
        policy = actor(torch.from_numpy(state).to(device))
        action = np.random.choice(np.array([0, 1]), p=policy.cpu().data.numpy())
        next_state, reward, done, info = env.step(action)
        tot_rewards += 1
        transitions.append((state, action, tot_rewards, next_state))
        state = next_state
    if i % 50 == 0:
        print("i = ", i, ",reward = ", tot_rewards)
    score.append(tot_rewards)
    reward_batch = torch.Tensor([r for (s, a, r, ns) in transitions]).flip(dims=(0,))
    disc_rewards = discount_rewards(reward_batch)
    nrml_disc_rewards = normalize_rewards(disc_rewards).to(device)
    state_batch = torch.Tensor([s for (s, a, r, ns) in transitions]).to(device)
    action_batch = torch.Tensor([a for (s, a, r, ns) in transitions]).to(device)
    next_state_batch = torch.Tensor([ns for (s, a, r, ns) in transitions]).to(device)
    print("state_batch = ", state_batch.shape)
    pred_batch = actor(state_batch)
    prob_batch = pred_batch.gather(dim=1, index=action_batch.long().view(-1, 1)).squeeze()
    values = critic(state_batch).squeeze()
    # next_state_batch is already a tensor; torch.from_numpy() on it would raise a TypeError
    value_next = critic(next_state_batch)
    advantage = nrml_disc_rewards - values
    critic_loss = advantage.pow(2).mean()
    actor_loss = -(torch.sum(torch.log(prob_batch) * advantage))

    policy_opt.zero_grad()
    actor_loss.backward()
    policy_opt.step()

    value_opt.zero_grad()
    critic_loss.backward()
    value_opt.step()

    if i % 50 == 0:
        plt.scatter(np.arange(len(score)), score)
        plt.show(block=False)
        plt.pause(3)
        plt.close()

plt.scatter(np.arange(len(score)), score)
plt.show()

Both losses, critic_loss and actor_loss, use the advantage tensor in their computation, and advantage is computed from values, the output of the critic.
The first actor_loss.backward() call therefore backpropagates through the critic's part of the graph as well and frees the intermediate activations stored during that forward pass. This causes critic_loss.backward() to fail, since both backward passes depend on the computation graph (and the intermediate activations) attached to advantage.
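A minimal sketch that reproduces the failure outside the environment loop. The tiny `actor`/`critic` networks and the random `states`, `actions` and `returns` are hypothetical stand-ins (CartPole-sized shapes), not the networks from the question; only the loss and backward structure mirrors the original code.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical tiny stand-ins for Actor and Critic (4-dim states, 2 actions).
actor = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=-1))
critic = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))

states = torch.randn(8, 4)           # fake batch of states
actions = torch.randint(0, 2, (8,))  # fake actions taken
returns = torch.rand(8)              # fake normalized discounted rewards

prob_batch = actor(states).gather(1, actions.view(-1, 1)).squeeze()
values = critic(states).squeeze()
advantage = returns - values         # advantage is attached to the critic's graph

critic_loss = advantage.pow(2).mean()
actor_loss = -(torch.sum(torch.log(prob_batch) * advantage))

# Backpropagates into the actor AND, through advantage, into the critic,
# freeing the critic's saved activations along the way.
actor_loss.backward()

failed = False
try:
    critic_loss.backward()           # needs the critic activations that were just freed
except RuntimeError as err:
    failed = True
    print("second backward failed:", err)
```

Removing `advantage` from `actor_loss` (for example, using only `torch.log(prob_batch)`) makes the second call succeed, which confirms that the shared critic subgraph is what gets freed.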
To solve the issue, you could use actor_loss.backward(retain_graph=True) or, if it fits your use case, sum both losses together before calling .backward() on the sum.
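Both suggestions can be sketched on the same kind of hypothetical toy setup (small stand-in networks, random data, `lr=1e-4` mirroring `learning_rate` in the question). This is an illustration under those assumptions, not a drop-in patch for the full script.

```python
import torch
from torch import nn


def make_setup(seed=0):
    # Hypothetical tiny actor/critic pair, their optimizers, and a fake batch.
    torch.manual_seed(seed)
    actor = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=-1))
    critic = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
    policy_opt = torch.optim.Adam(actor.parameters(), lr=1e-4)
    value_opt = torch.optim.Adam(critic.parameters(), lr=1e-4)
    states, actions, returns = torch.randn(8, 4), torch.randint(0, 2, (8,)), torch.rand(8)
    prob_batch = actor(states).gather(1, actions.view(-1, 1)).squeeze()
    advantage = returns - critic(states).squeeze()
    critic_loss = advantage.pow(2).mean()
    actor_loss = -(torch.sum(torch.log(prob_batch) * advantage))
    return critic, policy_opt, value_opt, actor_loss, critic_loss


# Option 1: keep the saved tensors alive for the second backward call.
critic, policy_opt, value_opt, actor_loss, critic_loss = make_setup()
policy_opt.zero_grad()
actor_loss.backward(retain_graph=True)  # graph buffers are not freed here
policy_opt.step()
value_opt.zero_grad()                   # discards the critic grads that actor_loss produced
critic_loss.backward()                  # succeeds now
value_opt.step()
option1_ok = all(p.grad is not None for p in critic.parameters())

# Option 2: a single backward pass over the summed losses.
critic, policy_opt, value_opt, actor_loss, critic_loss = make_setup()
policy_opt.zero_grad()
value_opt.zero_grad()
(actor_loss + critic_loss).backward()   # one traversal, so nothing is read after being freed
policy_opt.step()
value_opt.step()
option2_ok = all(p.grad is not None for p in critic.parameters())
```

Option 1 keeps the two separate optimizer updates from the question. In option 2 the critic's gradient also contains the term flowing from actor_loss through advantage, so its update differs from option 1. If only the value loss should train the critic, a common pattern in actor-critic code is to use advantage.detach() inside actor_loss; the two backward calls then no longer share any part of the graph.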
