RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 1]], which is output 0 of AsStridedBackward0, is at version 4; expected version 2 instead

Hi, I’m training a model with a variant of the PPO algorithm and get the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 1]], which is output 0 of AsStridedBackward0, is at version 4; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
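
As far as I understand, the “version” in the message is autograd’s counter of in-place modifications on a tensor that was saved for the backward pass; here is a tiny standalone illustration, separate from my actual code:

import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()          # exp() saves its output tensor for the backward pass
print(y._version)    # 0 -- the version autograd records when saving y
y.add_(1)            # the in-place op bumps y's version counter
print(y._version)    # 1
y.sum().backward()   # raises the "modified by an inplace operation" error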

I’ve tried the usual approaches, such as adding .clone() calls and trying different variants of F.relu(), but it’s not doing the trick.

Here is the network I’m using:

class CNN(torch.nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    
    # Define the layers (they are automatically registered as parameters)
    # Shared layers
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(1,1), stride=1, padding = 3, bias = True)
    self.max_p1 = nn.MaxPool2d(2, stride=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(2,2), stride=1, padding = 3, bias = True)
    self.max_p2 = nn.MaxPool2d(2, stride=2)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5,5), padding = 1, bias = True)
    self.max_p3 = nn.MaxPool2d(2, stride=3)

    # FC layers for the policy head
    self.linear1 = nn.Linear(1024, 256)
    self.linear2 = nn.Linear(256, 64)
    self.linear3 = nn.Linear(64, 16)
    # FC layers for the value head
    self.linear_1 = nn.Linear(1024, 256)
    self.linear_2 = nn.Linear(256, 64)
    self.linear_3 = nn.Linear(64, 1)

    # Initialize the network parameters:
    nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv3.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear1.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear2.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear3.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear_1.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear_2.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear_3.weight, mode='fan_in', nonlinearity='relu')

  # Computes the forward pass
  def forward(self, x):
    if len(x.shape) == 3:
      x = x.unsqueeze(0)
    # Shared forward pass
    u1 = self.conv1(x)
    h1 = F.relu(u1)
    f1 = self.max_p1(h1)
    u2 = self.conv2(f1)
    h2 = F.relu(u2)
    f2 = self.max_p2(h2)
    u3 = self.conv3(f2)
    h3 = F.relu(u3)
    f3 = self.max_p3(h3)
    m = torch.flatten(input = f3, start_dim=1)
    # Forward Policy
    u3 = self.linear1(m)
    h3 = F.relu(u3)
    u4 = self.linear2(h3)
    h4 = F.relu(u4)
    u5 = self.linear3(h4)
    y_pred = F.softmax(u5, dim=-1)
    # Forward value
    u_3 = self.linear_1(m)
    h_3 = F.relu(u_3)
    u_4 = self.linear_2(h_3)
    h_4 = F.relu(u_4)
    value_pred = self.linear_3(h_4)
    return y_pred, value_pred

And here is the training loop:

torch.autograd.set_detect_anomaly(True)
env = FireGrid_V4(20, burn_value=10, n_sims=50)
net = CNN()
gamma = 0.99
alpha = 1e-4
clip = 0.2
optimizer = AdamW(net.parameters(), lr = alpha)
stats = {"Actor Loss": [], "Critic Loss": [], "Returns": []}
step_data = []
for episode in tqdm(range(1, 10 + 1)):
    state = env.reset()
    done = False
    ep_return  = 0
    I = 1.
    while not done:
        state_c = state.clone()
        policy, value = net.forward(state_c)
        action = policy.multinomial(1)
        next_state, reward, done = env.step(action.detach())
        next_state_c = next_state.clone()
        _, value_next_state = net.forward(next_state_c)
        I *= gamma
        step_data.append([state_c, action.clone(), reward.clone(), policy.clone(), value.clone(), value_next_state.clone(), I])
        state = next_state
        ep_return += reward
    data = DataLoader(step_data, 100, shuffle=False)
    if episode % 5 == 0:
        for e in range(10):
            print(e)
            for state_t, action_t, reward_t, policy_t, value_t, value_next_state_t, discounts in data:
              net.zero_grad()
              target = reward_t + gamma * value_next_state_t.clone()
              critic_loss = F.mse_loss(value_t.clone().squeeze(), target)
              advantage = (target - value_t.clone()).squeeze()
              new_probs, _ = net.forward(state_t)
              new_log_probs = torch.log(new_probs.squeeze() + 1e-6)
              log_probs = torch.log(policy_t.squeeze() + 1e-6)
              action_log_probs = log_probs.gather(1, action_t.squeeze().unsqueeze(1))
              new_action_log_probs = new_log_probs.gather(1, action_t.squeeze().unsqueeze(1))
              prob_ratio = torch.exp(new_action_log_probs)/torch.exp(action_log_probs)
              weighted_probs = prob_ratio * advantage
              weighted_clipped_probs = torch.clamp(prob_ratio, 1-clip, 1+clip)*advantage
              entropy = -torch.sum(policy_t.squeeze() * log_probs.squeeze(), dim = -1, keepdim = True)
              actor_loss = torch.sum(- discounts * torch.minimum(weighted_probs, weighted_clipped_probs) - 0.02*entropy)
              total_loss = actor_loss + critic_loss
              total_loss.backward(retain_graph = True)
              optimizer.step()
        step_data = []

The first batch is backpropagated correctly and optimizer.step() works, but the second batch raises the error :frowning:
Any ideas? Thanks in advance!

Using retain_graph=True is usually wrong, can cause the “inplace” error, and is often used as a workaround for another error.
Could you describe why you are using it and check whether it’s really needed?
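
For context, here is a minimal, generic sketch (not your code) of the error that retain_graph=True is typically added to silence, i.e. backpropagating a second time through a graph that has already been freed:

import torch

w = torch.randn(1, requires_grad=True)
loss = (w * 2).sum()
loss.backward()          # the graph is freed after this call
# loss.backward()        # would raise "Trying to backward through the graph a second time"
# loss.backward(retain_graph=True) silences it, but the usual fix is to recompute
# the forward pass instead of reusing a stale graph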

Hi!
It isn’t needed; I removed it, but I get the same error :frowning:

OK, good to hear it’s not needed. The only inplace operation I’m seeing is:

I *= gamma

Could you replace it with its out-of-place version:

I = I * gamma

and see if this would help?

Hi!
I replaced it as you suggested, but exactly the same error message is raised :frowning:

In that case, could you post a minimal and executable code snippet which would reproduce the issue so that we could try to debug it, please?

Sure! The snippet’s logic is a bit different from the full code, which includes a simulator that’s too large to share, but it raises the same error, so it should do.

Here is the CNN I shared before:

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import sys
import time
# from enviroment.firegrid_v4 import FireGrid_V4  # not needed for this minimal repro
import copy
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.optim import AdamW
import datetime
from tqdm import tqdm
# PyTorch-style network
class CNN(torch.nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    
    # Define the layers (they are automatically registered as parameters)
    # Shared layers
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(1,1), stride=1, padding = 3, bias = True)
    self.max_p1 = nn.MaxPool2d(2, stride=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(2,2), stride=1, padding = 3, bias = True)
    self.max_p2 = nn.MaxPool2d(2, stride=2)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5,5), padding = 1, bias = True)
    self.max_p3 = nn.MaxPool2d(2, stride=3)

    # FC layers for the policy head
    self.linear1 = nn.Linear(1024, 256)
    self.linear2 = nn.Linear(256, 64)
    self.linear3 = nn.Linear(64, 16)
    # FC layers for the value head
    self.linear_1 = nn.Linear(1024, 256)
    self.linear_2 = nn.Linear(256, 64)
    self.linear_3 = nn.Linear(64, 1)

    # Initialize the network parameters:
    nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv3.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear1.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear2.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear3.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear_1.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear_2.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.linear_3.weight, mode='fan_in', nonlinearity='relu')

  # Computes the forward pass
  def forward(self, x):
    if len(x.shape) == 3:
      x = x.unsqueeze(0)
    # Shared forward pass
    u1 = self.conv1(x)
    h1 = F.relu(u1)
    f1 = self.max_p1(h1)
    # print(f1.shape)
    u2 = self.conv2(f1)
    h2 = F.relu(u2)
    f2 = self.max_p2(h2)
    # print(f2.shape)
    u3 = self.conv3(f2)
    h3 = F.relu(u3)
    f3 = self.max_p3(h3)
    # print(f3.shape)
    m = torch.flatten(input = f3, start_dim=1)
    # Forward Policy
    u3 = self.linear1(m)
    h3 = F.relu(u3)
    u4 = self.linear2(h3)
    h4 = F.relu(u4)
    u5 = self.linear3(h4)
    y_pred = F.softmax(u5, dim=-1)
    # Forward value
    u_3 = self.linear_1(m)
    h_3 = F.relu(u_3)
    u_4 = self.linear_2(h_3)
    h_4 = F.relu(u_4)
    value_pred = self.linear_3(h_4)
    return y_pred, value_pred

And here is the modified training loop that generates the error:

def env2_step(step):
  state = torch.rand(3, 20, 20)
  reward = torch.rand(1)
  if step == 100:
    return state, reward, True
  return state, reward, False

torch.autograd.set_detect_anomaly(True)

net = CNN()
gamma = 0.99
alpha = 1e-4
clip = 0.2
optimizer = AdamW(net.parameters(), lr = alpha)
stats = {"Actor Loss": [], "Critic Loss": [], "Returns": []}
step_data = []
for episode in tqdm(range(1, 10 + 1)):
    state, _, _ = env2_step(0)
    done = False
    ep_return  = 0
    I = 1.
    step = 1
    while not done:
        state_c = state.clone()
        policy, value = net.forward(state_c)
        action = policy.multinomial(1)
        next_state, reward, done = env2_step(step)
        next_state_c = next_state.clone()
        _, value_next_state = net.forward(next_state_c)
        I = I * gamma
        step_data.append([state_c, action.clone(), reward.clone(), policy.clone(), value.clone(), value_next_state.clone(), I])
        state = next_state
        ep_return += reward
        step = step + 1 
    data = DataLoader(step_data, 100, shuffle=False)
    if episode % 5 == 0:
        for e in range(10):
            print(e)
            for state_t, action_t, reward_t, policy_t, value_t, value_next_state_t, discounts in data:
              net.zero_grad()
              target = reward_t + gamma * value_next_state_t
              critic_loss = F.mse_loss(value_t, target)
              advantage = (target - value_t.clone()).squeeze(1)
              new_probs, _ = net.forward(state_t)
              new_log_probs = torch.log(new_probs + 1e-6)
              log_probs = torch.log(policy_t + 1e-6)
              action_log_probs = log_probs.gather(2, action_t).squeeze(1)
              new_action_log_probs = new_log_probs.gather(1, action_t.squeeze(1))
              prob_ratio = torch.exp(new_action_log_probs)/torch.exp(action_log_probs)
              weighted_probs = prob_ratio * advantage
              weighted_clipped_probs = torch.clamp(prob_ratio, 1-clip, 1+clip)*advantage
              entropy = -torch.sum(policy_t * log_probs, dim = -1, keepdim = True).squeeze(1)
              actor_loss = torch.sum(- discounts * torch.minimum(weighted_probs, weighted_clipped_probs) - 0.02*entropy)
              total_loss = actor_loss + critic_loss
              total_loss.backward()
              optimizer.step()
        step_data = []

Thanks in advance! :slight_smile:

Based on your code snippet, the error disappears when policy_t, value, and value_next_state are detached.
All of these tensors are created by the model inside the while loop, where the inputs also depend recursively on the previous iteration. After this while loop has run for 5 episodes, you then try to update the model in a separate loop, which unfortunately fails.
I didn’t spend enough time on your code to fully figure out the dependencies between all tensors, but I hope you can use the “working” (but potentially wrong) approach to check where the needed activations were manipulated in-place.
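
In case it helps, here is a minimal sketch of where the detach() calls would go in your rollout loop (keeping in mind, as said above, that this “working” version might not be what you actually want):

while not done:
    state_c = state.clone()
    policy, value = net.forward(state_c)
    action = policy.multinomial(1)
    next_state, reward, done = env2_step(step)
    _, value_next_state = net.forward(next_state.clone())
    I = I * gamma
    # store the rollout results without their autograd history, so that
    # backward() in the update loop only uses the freshly recomputed graph
    step_data.append([state_c, action.clone(), reward.clone(),
                      policy.detach(), value.detach(),
                      value_next_state.detach(), I])
    state = next_state
    ep_return += reward
    step = step + 1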
