Struggling with a RuntimeError related to in-place operations

This is the error I have encountered:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [20, 80]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The functionality works with linear layers (the basic PPO implementation comes from https://github.com/higgsfield/RL-Adventure-2). However, I have been trying to add some LSTM layers, which required me to reshape the inputs into sequences and make some other changes. I'm not sure where the in-place operations that break the gradient computation are occurring.
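For context, my (possibly wrong) understanding is that this error is raised when a tensor that autograd saved for the backward pass gets modified in place before backward() runs. A toy snippet like this (not from my project, just to illustrate the error class) reproduces it:

import torch

x = torch.randn(5, requires_grad=True)
y = torch.exp(x)    # exp() saves its output for the backward pass
y += 1              # in-place add bumps the version counter of the saved tensor
y.sum().backward()  # RuntimeError: ... modified by an inplace operation

I just can't see where anything like that happens in my own code, which follows below.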

def add_states(self, actor_, envs, state, idx, entropy):
        '''
        add states to the batch of state, action, rewards ... etc pairs
        
        input(s): state-tensor with current state, current index in the data set, entropy-constant representing entropy
        '''
        state = torch.FloatTensor(state)#.to(device)
        
        if self.type_ == 'LSTM':
            ###create batch dimension
            state = state.unsqueeze(1)
        
        dist = actor_(state)
        value = self.critic_(state)

        action = dist.sample()

        if self.discrete:
            temp = action.numpy().item()
        else:
            temp = F.softmax(action, dim = 1)
            temp = temp.numpy()[0]

        if self.type_ == 'LSTM':
            state = state.squeeze(1)[-1:, :]
            
        reward, next_state = envs.step(temp, state.numpy(), idx)

        log_prob = dist.log_prob(action)

        entropy += dist.entropy().mean()
        
        return log_prob.unsqueeze(1), value, torch.FloatTensor(np.array([reward])).unsqueeze(1), state, action.unsqueeze(1), next_state, entropy 

    
    def train(self, multiple_runs):
        
        self.rewards_all = []
        update_counter = 0
        for i in range(multiple_runs):
            if i == 0:
                state, j = self.envs.initial_state()
            else:
                state, j = self.envs.initial_state(random = True)
            max_idx = self.envs.data_df.shape[0] - j
            idx  = j + 1
        
            while idx < max_idx:
                log_probs = []
                values    = []
                states    = []
                actions   = []
                rewards   = []
                entropy = 0

                # self.actor_.eval()
                # self.critic_.eval()
                # with torch.no_grad():

                if self.type_ == 'LSTM':
                    saved_cell_actor = (self.actor_.hidden_cell[0].clone(),self.actor_.hidden_cell[1].clone())
                    saved_cell_critic = (self.critic_.hidden_cell[0].clone(),self.critic_.hidden_cell[1].clone())

                    self.actor_.hidden_cell = (self.actor_.hidden_cell[0][:, -1:, :].clone(), self.actor_.hidden_cell[1][:, -1:, :].clone())
                    self.critic_.hidden_cell = (self.critic_.hidden_cell[0][:, -1:, :].clone(), self.critic_.hidden_cell[1][:, -1:, :].clone())

                    for k in range(self.seq_len-1):
                        states.append(torch.FloatTensor(state[k:k+1,:]))

                for _ in range(self.num_steps):

                    if self.type_ == 'LSTM':
                        temp_state = state[1:, :]

                    log_prob, value, reward, state, action, next_state, entropy = self.add_states(self.actor_, self.envs, state, idx, entropy)
                    
                    log_probs.append(log_prob)
                    values.append(value)
                    rewards.append(reward)#.to(device))
                    states.append(state)
                    actions.append(action)

                    state = next_state

                    if self.type_ == 'LSTM':
                        state = np.concatenate((temp_state, state), axis = 0)

                    self.rewards_all.append(reward)
                    idx += 1
                    if idx >= max_idx:
                        break
                
                next_state = torch.FloatTensor(next_state)

                if self.type_ == 'LSTM':     
                    next_state = torch.FloatTensor(np.concatenate((state[1:, :], next_state), axis = 0)).unsqueeze(1)#.to(device)

                next_value = self.critic_(next_state)
                next_value = next_value.item()

                returns = self.sf.compute_gae(next_value, rewards, values)
                
                
                returns   = torch.cat(returns).detach()
                log_probs = torch.cat(log_probs).detach()
                values    = torch.cat(values).detach()
                states    = torch.cat(states)
                actions   = torch.cat(actions)
                advantage = returns - values

                if self.type_ == 'LSTM':
                    self.actor_.hidden_cell = saved_cell_actor 
                    self.critic_.hidden_cell = saved_cell_critic

                
                # self.actor_.train()
                # self.critic_.train()

                update_counter += 1
                if update_counter % 3 == 0:
                    print('{} % done.'.format((idx/max_idx)*100))       
                    
                self.sf.ppo_update(states, actions, log_probs, returns, advantage, self.actor_, self.critic_, self.optimizer_actor, self.optimizer_critic)
    def ppo_update(self, states, actions, log_probs, returns, advantages, actor_, critic_, optimizer_actor, optimizer_critic, clip_param=0.2):
        '''
        training loop: iterate through batches of states, actions, returns and advantages 
        '''
        for _ in range(self.ppo_epochs):
            for state, action, old_log_probs, return_, advantage in self.ppo_iter(states, actions, log_probs, returns, advantages):
                #with torch.autograd.set_detect_anomaly(True):
                ###obtain outputs from neural nets
                dist = actor_(state)
                value = critic_(state)

                ###calculate entropy and log_prob
                entropy = dist.entropy().mean()
                new_log_probs = dist.log_prob(action)

                # print(action.size())
                # print(new_log_probs.size())
                # print(advantage.size())

                ###calculate ppo ratio from paper
                ratio = (new_log_probs - old_log_probs).exp()
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage

                ###actor and critic losses
                actor_loss  = - torch.min(surr1, surr2).mean() - 0.001 * entropy
                mse_loss = nn.MSELoss()
                critic_loss = mse_loss(value, return_) - 0.001 * entropy

                print(actor_loss.grad_fn)
                #actor_loss.register_hook(lambda grad: print(grad))
                ###update
                optimizer_actor.zero_grad()
                optimizer_critic.zero_grad()
                

                actor_loss.backward(retain_graph=True)
                critic_loss.backward(retain_graph=True)
                
                optimizer_actor.step()
                optimizer_critic.step()

   def f(self, temp_tensor):
        '''
        helper function for building out the sequences 
        '''
        temp = [temp_tensor[j:j+self.seq_len, :].clone().unsqueeze(1) for j in range(self.mini_batch_size)]
        # for i in temp:
        #     print(i.size())
        return torch.cat(temp, 1)

    def ppo_iter(self, states, actions, log_probs, returns, advantage):
        '''
        iterate through a set of the data using minibatches
        input(s):
        '''
        batch_size = states.size(0)
        for _ in range(batch_size // self.mini_batch_size):
            #rand_ids = np.random.randint(0, batch_size, mini_batch_size)
            #lump = self.mini_batch_size + self.seq_len 
            rand_ids = np.random.randint(self.seq_len-1, batch_size - self.mini_batch_size)
            #yield states[rand_ids, :], actions[rand_ids, :], log_probs[rand_ids, :], returns[rand_ids, :], advantage[rand_ids, :]
            yield self.f(states[rand_ids-self.seq_len+1:rand_ids+self.mini_batch_size:, :]), actions[rand_ids:rand_ids+self.mini_batch_size, :], log_probs[rand_ids:rand_ids+self.mini_batch_size, :], returns[rand_ids:rand_ids+self.mini_batch_size, :], advantage[rand_ids:rand_ids+self.mini_batch_size, :]
   
def set_up_layers(self, state_dim, action_dim):
        '''
        make structure for the actor network
        '''
        if self.type_ == 'linear':

            layers = base_net(state_dim, self.dim_1, self.dim_2, self.type_, self.dropout)
            layers.append(nn.Linear(self.dim_2, action_dim))
        
            self.model = nn.Sequential(*layers)
        
        elif self.type_ == 'LSTM':
            self.lstm = nn.LSTM(state_dim, hidden_size = self.hidden_dim)
            self.lin_cap = nn.Linear(self.hidden_dim, action_dim)
            

    
    def forward(self, state):
        '''
        forward propagate the network
        input(s): state
        output(s): result of the actor network
        '''
        out = self.model(state)
        if self.discrete:
            probs = F.softmax(out, dim = 1)
            return Categorical(probs)
        else: 
            mu  = self.model(state)
            std   = self.log_std.exp().expand_as(mu)
            dist  = Normal(mu, std)
            return dist
            

class actor_LSTM(actor):
    '''
    LSTM actor class
    need to overwrite the forward function
    '''
    def __init__(self, state_dim, action_dim, params = default_params):
        
        super(actor_LSTM, self).__init__(state_dim, action_dim, params = params)

        self.reset_hidden_cell()

    def reset_hidden_cell(self):
        '''
        resets the hidden cell state, or initializes it 
        '''
        self.hidden_cell = (torch.zeros(1,self.mini_batch_size,self.hidden_dim), torch.zeros(1,self.mini_batch_size,self.hidden_dim))

    def forward(self, x):
        lstm_out, self.hidden_cell = self.lstm(x, self.hidden_cell)
        
        mu = self.lin_cap(lstm_out[-1])

        if self.discrete:
            probs = F.softmax(mu, dim = 1)
            return Categorical(probs)
        else: 
            std   = self.log_std.exp().expand_as(mu)
            dist  = Normal(mu, std)
            return dist

Thanks in advance

Tried this as an update; still the same error:

def ppo_iter(self, states, actions, log_probs, returns, advantage):
        '''
        iterate through a set of the data using minibatches
        input(s):
        '''
        batch_size = states.size(0) - self.mini_batch_size
        for _ in range(batch_size // self.mini_batch_size):
            #rand_ids = np.random.randint(0, batch_size, mini_batch_size)
            #lump = self.mini_batch_size + self.seq_len 
            rand_ids = np.random.randint(self.mini_batch_size, batch_size - self.mini_batch_size)
            inds = np.array([[(rand_ids - 1) - i - j for i in reversed(range(self.seq_len))] for j in reversed(range(self.mini_batch_size))])
            inds= torch.tensor(inds.T)
            inds = inds.reshape(inds.size(0)*inds.size(1))
            temp = states.unsqueeze(1)
            x = torch.index_select(temp, 0, inds).view(self.seq_len, self.mini_batch_size, -1)
            #yield states[rand_ids, :], actions[rand_ids, :], log_probs[rand_ids, :], returns[rand_ids, :], advantage[rand_ids, :]
            yield x, actions[rand_ids- self.mini_batch_size:rand_ids, :], log_probs[rand_ids- self.mini_batch_size:rand_ids, :], returns[rand_ids- self.mini_batch_size:rand_ids, :], advantage[rand_ids- self.mini_batch_size:rand_ids, :]

I cannot find the in-place operation in your code which might raise this error.
Could you add the class definition and some random tensors, so that I could execute and debug the code, please?

Could this be a potential clue?

@ptrblck Thanks for taking a look at it. I made a super simple environment class so everything can run together. The simple environment still works with linear layers, just not with LSTM.

Simple environment and some parameters

default_params = {
    'discrete' : False,
    'lr' : 3e-4,
    'num_steps' : 300,
    'mini_batch_size' : 30,
    'ppo_epochs' : 6,
    'action_dim' : 2,
    'seq_len' : 10,
    'type_' : 'LSTM',
    
    'dropout' : False,
    'std' : 0.0,
    'hidden_dim' : 20,
    'dim_1' : 64,
    'dim_2' : 32
    
}


class debug_env():

	def __init__(self, params = default_params):

		self.action_dim = 2
		self.state_dim = 4
		self.n_pts = 50000
		self.seq_len = params['seq_len']
		self.data_df = self.build()

	def build(self):
		return pd.DataFrame(np.random.normal(size=(self.n_pts, self.state_dim)))

	def initial_state(self):
		i = 0
		return self.data_df.iloc[i:i+self.seq_len].values, i +self.seq_len - 1

	def step(self, action, state, idx):
		return self.rewards(), self.data_df.iloc[idx].values.reshape(1, -1)
	
	def rewards(self):
		return np.random.normal()

Training Loop and What not

###used for single actor###
class trainer_tester():
    '''
    class for implementing training/testing functionality
    '''
    def __init__(self, envs, params = default_params):
        ###make the environment
        self.envs = envs
        self.discrete = params['discrete']
        
        ### initialize some hyper parameters
        self.state_dim = envs.state_dim 
        self.action_dim = envs.action_dim

        ###hyper params:
        self.lr               = params['lr']
        self.num_steps        = params['num_steps']
        self.mini_batch_size  = params['mini_batch_size']
        self.ppo_epochs       = params['ppo_epochs']
        self.seq_len          = envs.seq_len
        self.type_            = params['type_']


        if self.type_ == 'linear':
            self.sf = training_functionality(params)
            ###initialize actor and critic
            self.actor_ = actor(self.state_dim, self.action_dim, params = params)
            self.critic_ = critic(self.state_dim, params = params)

        elif self.type_ == 'LSTM':
            self.sf = training_functionality_LSTM(params)
            ###initialize actor and critic
            self.actor_ = actor_LSTM(self.state_dim, self.action_dim, params = params)
            self.critic_ = critic_LSTM(self.state_dim, params = params)
            
        self.optimizer_actor = torch.optim.Adam(self.actor_.parameters(), lr = self.lr)
        self.optimizer_critic = torch.optim.Adam(self.critic_.parameters(), lr = self.lr)
       
    def add_states(self, actor_, envs, state, idx, entropy):
        '''
        add states to the batch of state, action, rewards ... etc pairs
        
        input(s): state-tensor with current state, current index in the data set, entropy-constant representing entropy
        '''
        state = torch.FloatTensor(state)#.to(device)
        
        if self.type_ == 'LSTM':
            ###create batch dimension
            state = state.unsqueeze(1)
        
        dist = actor_(state)
        value = self.critic_(state)

        action = dist.sample()

        if self.discrete:
            temp = action.numpy().item()
        else:
            temp = F.softmax(action, dim = 1)
            temp = temp.numpy()[0]

        if self.type_ == 'LSTM':
            state = torch.index_select(state, 0, torch.tensor([state.size(0) -1])).squeeze(1)
            
        reward, next_state = envs.step(temp, state.numpy(), idx)

        log_prob = dist.log_prob(action)

        entropy += dist.entropy().mean()
        
        return log_prob.unsqueeze(1), value, torch.FloatTensor(np.array([reward])).unsqueeze(1), state, action.unsqueeze(1), next_state, entropy 

    
    def train(self, multiple_runs):
        
        self.rewards_all = []
        update_counter = 0
        for i in range(multiple_runs):
            if i == 0:
                state, j = self.envs.initial_state()
            else:
                state, j = self.envs.initial_state(random = True)
            max_idx = self.envs.data_df.shape[0] - j
            idx  = j + 1
        
            while idx < max_idx:
                log_probs = []
                values    = []
                states    = []
                actions   = []
                rewards   = []
                entropy = 0

                # self.actor_.eval()
                # self.critic_.eval()
                # with torch.no_grad():

                if self.type_ == 'LSTM':
                    saved_cell_actor = (self.actor_.hidden_cell[0].clone(),self.actor_.hidden_cell[1].clone())
                    saved_cell_critic = (self.critic_.hidden_cell[0].clone(),self.critic_.hidden_cell[1].clone())

                    self.actor_.hidden_cell = (self.actor_.hidden_cell[0][:, -1:, :].clone(), self.actor_.hidden_cell[1][:, -1:, :].clone())
                    self.critic_.hidden_cell = (self.critic_.hidden_cell[0][:, -1:, :].clone(), self.critic_.hidden_cell[1][:, -1:, :].clone())

                    # self.actor_.reset_hidden_cell(1)
                    # self.critic_.reset_hidden_cell(1)

                    for k in range(self.seq_len-1):
                        states.append(torch.FloatTensor(state[k:k+1,:]))

                for _ in range(self.num_steps):

                    if self.type_ == 'LSTM':
                        temp_state = state[1:, :]

                    log_prob, value, reward, state, action, next_state, entropy = self.add_states(self.actor_, self.envs, state, idx, entropy)
                    
                    log_probs.append(log_prob)
                    values.append(value)
                    rewards.append(reward)#.to(device))
                    states.append(state)
                    actions.append(action)

                    state = next_state


                    if self.type_ == 'LSTM':
                        state = np.concatenate((temp_state, state), axis = 0)

                    self.rewards_all.append(reward)

                    idx += 1
                    if idx >= max_idx:
                        break

                next_state = torch.FloatTensor(next_state)



                if self.type_ == 'LSTM':     
                    next_state = torch.FloatTensor(np.concatenate((state[1:, :], next_state), axis = 0)).unsqueeze(1)#.to(device)

                next_value = self.critic_(next_state)
                next_value = next_value.item()

                returns = self.sf.compute_gae(next_value, rewards, values)
                
                
                returns   = torch.cat(returns).detach()
                log_probs = torch.cat(log_probs).detach()
                values    = torch.cat(values).detach()
                states    = torch.cat(states)
                actions   = torch.cat(actions)
                advantage = returns - values

                if self.type_ == 'LSTM':
                    # self.actor_.reset_hidden_cell(self.mini_batch_size)
                    # self.critic_.reset_hidden_cell(self.mini_batch_size)
                    self.actor_.hidden_cell = saved_cell_actor 
                    self.critic_.hidden_cell = saved_cell_critic

                
                # self.actor_.train()
                # self.critic_.train()

                update_counter += 1
                if update_counter % 3 == 0:
                    print('{} % done.'.format((idx/max_idx)*100))       
                    
                self.sf.ppo_update(states, actions, log_probs, returns, advantage, self.actor_, self.critic_, self.optimizer_actor, self.optimizer_critic)
                
    def test(self, envs):
        '''
        Implementation to test some of the data without updating any parameters of the networks
        '''
        self.envs = envs
        self.rewards_all_test = []
        state, j = self.envs.initial_state()
        max_idx = self.envs.data_df.shape[0] - j
        idx  = j + 1
        rewards_all = []
        entropy = 0
        while idx < max_idx:
            log_prob, value, reward, state, action, next_state, entropy = self.add_states(self.actor_, envs, state, idx, entropy)
            state = next_state
            self.rewards_all_test.append(reward)
            idx += 1
            if idx >= max_idx:
                break

Supplementary functionality for training

class training_functionality():
    '''
    wrapper class for implementing training supplementary functions
    '''
    def __init__(self, params):
        self.mini_batch_size = params['mini_batch_size']
        self.ppo_epochs      = params['ppo_epochs']
        self.seq_len         = params['seq_len']


    def compute_gae(self, next_value, rewards, values, gamma=0.99, tau=0.95):
        '''
        Generalized Advantage Estimation (GAE) from the paper: used to estimate the advantage the actor gains
            from a particular action
        input(s):
        '''
        values = values + [next_value]
        gae = 0
        returns = []
        for step in reversed(range(len(rewards))):
            delta = rewards[step] + gamma * values[step + 1]  - values[step]
            gae = delta + gamma * tau * gae
            returns.insert(0, gae + values[step])
        return returns


    def ppo_iter(self, states, actions, log_probs, returns, advantage):
        '''
        iterate through a set of the data using minibatches
        input(s):
        '''

        batch_size = states.size(0)
        for _ in range(batch_size // self.mini_batch_size):
            rand_ids = np.random.randint(0, batch_size, self.mini_batch_size)
            #rand_ids = np.random.randint(0, batch_size - mini_batch_size)
            yield states[rand_ids, :], actions[rand_ids, :], log_probs[rand_ids, :], returns[rand_ids, :], advantage[rand_ids, :]
            #yield states[rand_ids:rand_ids+mini_batch_size:, :], actions[rand_ids:rand_ids+mini_batch_size, :], log_probs[rand_ids:rand_ids+mini_batch_size, :], returns[rand_ids:rand_ids+mini_batch_size, :], advantage[rand_ids:rand_ids+mini_batch_size, :]

    def ppo_update(self, states, actions, log_probs, returns, advantages, actor_, critic_, optimizer_actor, optimizer_critic, clip_param=0.2):
        '''
        training loop: iterate through batches of states, actions, returns and advantages 
        '''
        for _ in range(self.ppo_epochs):
            for state, action, old_log_probs, return_, advantage in self.ppo_iter(states, actions, log_probs, returns, advantages):
                #with torch.autograd.set_detect_anomaly(True):
    
                ###obtain outputs from neural nets
                dist = actor_(state)
                value = critic_(state)

                ###calculate entropy and log_prob
                entropy = dist.entropy().mean()
                new_log_probs = dist.log_prob(action)

                # print(action.size())
                # print(new_log_probs.size())
                # print(advantage.size())

                ###calculate ppo ratio from paper
                ratio = (new_log_probs - old_log_probs).exp()
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage

                ###actor and critic losses
                actor_loss  = - torch.min(surr1, surr2).mean() - 0.001 * entropy
                mse_loss = nn.MSELoss()
                critic_loss = mse_loss(value, return_) - 0.001 * entropy
        
                ###update
                optimizer_actor.zero_grad()
                optimizer_critic.zero_grad()
                

                actor_loss.backward(retain_graph=True)
                critic_loss.backward(retain_graph=True)
                
                optimizer_actor.step()
                optimizer_critic.step()



class training_functionality_LSTM(training_functionality):
    '''
    class for the LSTM training functionality 
    needed to overwrite a couple of functions
    '''
    def __init__(self, params):
        super(training_functionality_LSTM, self).__init__(params)

    def f(self, temp_tensor):
        ###deprecated
        '''
        helper function for building out the sequences 
        '''
        temp = [temp_tensor[j:j+self.seq_len, :].clone().unsqueeze(1) for j in range(self.mini_batch_size)]
        # for i in temp:
        #     print(i.size())
        return torch.cat(temp, 1)

    def ppo_iter(self, states, actions, log_probs, returns, advantage):
        '''
        iterate through a set of the data using minibatches
        input(s):
        '''
        #with torch.autograd.set_detect_anomaly(True):
        batch_size = states.size(0) - self.mini_batch_size
        for _ in range(batch_size // self.mini_batch_size):
            #rand_ids = np.random.randint(0, batch_size, mini_batch_size)
            #lump = self.mini_batch_size + self.seq_len 
            rand_ids = np.random.randint(self.mini_batch_size+self.seq_len, batch_size)
            #print(rand_ids)
            inds = np.array([[(rand_ids - 1) - i - j for i in reversed(range(self.seq_len))] for j in reversed(range(self.mini_batch_size))])
            inds= torch.tensor(inds.T)
            inds = inds.reshape(inds.size(0)*inds.size(1))
            temp = states.unsqueeze(1)
            #print(inds)
            x = torch.index_select(temp, 0, inds).view(self.seq_len, self.mini_batch_size, -1)
            #yield states[rand_ids, :], actions[rand_ids, :], log_probs[rand_ids, :], returns[rand_ids, :], advantage[rand_ids, :]
            yield x, actions[rand_ids- self.mini_batch_size:rand_ids, :], log_probs[rand_ids- self.mini_batch_size:rand_ids, :], returns[rand_ids- self.mini_batch_size:rand_ids, :], advantage[rand_ids- self.mini_batch_size:rand_ids, :]

network classes


def base_net(state_dim, dim_1, dim_2, type_, drop_out):
    '''
    function to define base_network
    '''
    if type_ == 'linear':
        layers = [
            nn.Linear(state_dim, dim_1),
            nn.Tanh(),
            nn.Linear(dim_1, dim_2),
            nn.Tanh()
        ]
        if drop_out:
            layers.insert(2, nn.Dropout(p = 0.3))
            layers.append(nn.Dropout(p = 0.3))
        
    
    return layers


###actor build###       
class actor(nn.Module):
    '''
    The actor class
    '''
    def __init__(self, state_dim, action_dim, params = default_params):
        '''
        input(s): dimensions of state and action, and the max action
        '''
        super(actor, self).__init__()
        
        ###for a discrete action space
        self.discrete = params['discrete']
        self.type_ = params['type_']
        self.dropout = params['dropout']
        self.std = params['std']
        self.hidden_dim = params['hidden_dim']
        self.dim_1 = params['dim_1']
        self.dim_2 = params['dim_2']
        self.mini_batch_size = params['mini_batch_size']
        
        self.log_std = nn.Parameter(torch.ones(1, action_dim) * self.std)
        
        
        self.set_up_layers(state_dim, action_dim)
        
        ###apply the intial weights if needed
        #self.apply(init_weights)
    
    def set_up_layers(self, state_dim, action_dim):
        '''
        make structure for the actor network
        '''
        if self.type_ == 'linear':

            layers = base_net(state_dim, self.dim_1, self.dim_2, self.type_, self.dropout)
            layers.append(nn.Linear(self.dim_2, action_dim))
        
            self.model = nn.Sequential(*layers)
        
        elif self.type_ == 'LSTM':
            self.lstm = nn.LSTM(state_dim, hidden_size = self.hidden_dim)
            self.lin_cap = nn.Linear(self.hidden_dim, action_dim)
            

    
    def forward(self, state):
        '''
        forward propagate the network
        input(s): state
        output(s): result of the actor network
        '''
        out = self.model(state)
        if self.discrete:
            probs = F.softmax(out, dim = 1)
            return Categorical(probs)
        else: 
            mu  = self.model(state)
            std   = self.log_std.exp().expand_as(mu)
            dist  = Normal(mu, std)
            return dist
            

###critic build###
class critic(nn.Module):
    '''
    critic class
    '''
    def __init__(self, state_dim, params = default_params):
        '''
        input(s): dimensions of state and action
        '''
        super(critic, self).__init__()
        self.type_ = params['type_']
        self.dropout = params['dropout']
        self.hidden_dim = params['hidden_dim']
        self.dim_1 = params['dim_1']
        self.dim_2 = params['dim_2']
        self.mini_batch_size = params['mini_batch_size']
        
        self.set_up_layers(state_dim)
        
    
    def set_up_layers(self, state_dim):
        '''
        make structure for the actor network
        '''
        if self.type_ == 'linear':
            layers = base_net(state_dim, self.dim_1, self.dim_2, self.type_, self.dropout)
            layers.append(nn.Linear(self.dim_2, 1))
            self.model = nn.Sequential(*layers)

        elif self.type_ == 'LSTM':
            self.lstm = nn.LSTM(state_dim, hidden_size = self.hidden_dim)
            self.lin_cap = nn.Linear(self.hidden_dim, 1)
        
    
    def forward(self, state):
        '''
        forward propagate the network
        input(s): state and action
        output(s): result of the critic network
        '''
        return self.model(state)




class actor_LSTM(actor):
    '''
    LSTM actor class
    need to overwrite the forward function
    '''
    def __init__(self, state_dim, action_dim, params = default_params):
        
        super(actor_LSTM, self).__init__(state_dim, action_dim, params = params)

        self.reset_hidden_cell(self.mini_batch_size)

    def reset_hidden_cell(self, mini_batch_size):
        '''
        resets the hidden cell state, or initializes it 
        '''
        self.hidden_cell = (torch.zeros(1,mini_batch_size,self.hidden_dim), torch.zeros(1,mini_batch_size,self.hidden_dim))

    def forward(self, x):
        
        lstm_out, self.hidden_cell = self.lstm(x, self.hidden_cell)
        mu = self.lin_cap(lstm_out[-1])

        if self.discrete:
            probs = F.softmax(mu, dim = 1)
            return Categorical(probs)
        else: 
            std   = self.log_std.exp().expand_as(mu)
            dist  = Normal(mu, std)
            return dist

class critic_LSTM(critic):
    '''
    LSTM critic class
    need to overwrite the forward function
    '''
    def __init__(self, state_dim, params = default_params):
        
        super(critic_LSTM, self).__init__(state_dim, params = params)
        
        self.reset_hidden_cell(self.mini_batch_size)
    
    def reset_hidden_cell(self, mini_batch_size):
        '''
        resets the hidden cell state, or initializes it 
        '''
        self.hidden_cell = (torch.zeros(1,mini_batch_size,self.hidden_dim), torch.zeros(1,mini_batch_size, self.hidden_dim))
    
    def forward(self, x):
        #.view(len(x) ,1, -1)
        lstm_out, self.hidden_cell = self.lstm(x, self.hidden_cell)
        #lstm_out[-1].clone()
        return self.lin_cap(lstm_out[-1])

calls

de = debug_env()

trainer = trainer_tester(de)
trainer.train(1)

Thanks again for the help. I've been stumped for days and have gone through all the similar errors other people have run into.

Thanks for the code.
It’s quite long so I haven’t looked at it deeply enough to understand all nuances of the training.
However, I guess the inplace error might come from these lines of code:

                actor_loss.backward(retain_graph=True)
                critic_loss.backward(retain_graph=True)
                
                optimizer_actor.step()
                optimizer_critic.step()

Here it seems you are retaining the graph and trying to update the “old” parameters multiple times.
I.e. the vanilla workflow would be:

  • model parameters are in state P0
  • execute forward pass
  • execute backward pass and calculate gradients
  • update parameters using their gradients via optimizer.step()

In your code snippet you are using retain_graph=True, which will keep the intermediate activations so that the gradients (from previous steps) can be calculated again.
However, since the optimizer.step() calls were already performed, these gradients would be wrong and thus this error is raised.

Here is a small code snippet of what I mean:

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(1, 1)

out = model(x)
out.backward(retain_graph=True)
optimizer.step() # works

# gradients would be wrong, out is not calculated by the current parameters
out.backward(retain_graph=False) # error
optimizer.step() 

Would your training routine fit into this error description?

So, retain_graph=True is necessary for the first call to backward(), i.e. actor_loss.backward(), and the error is still present if I remove it from the second call to backward().

The error only occurs when I implement the LSTM functionality. I’m not sure if it works this way, but could it have something to do with the fact that state tensors are used more than once because they are part of a sequence and, therefore, the gradients are accumulating in individual state leaf tensors? However, I still don’t know if that makes sense because it doesn’t seem to have anything to do with inplace operations.

Thanks for your help.

I had a similar issue during the backward() call. I replaced the squeeze() and unsqueeze() calls with view() calls. It seems that the usage of squeeze() causes this in-place behavior, but I can't actually explain the dynamics of it.
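Roughly like this, as a sketch of the swap I mean (the shapes are just made up for illustration, not taken from the code above):

import torch

x = torch.randn(10, 4)

# before: adding/removing the extra dimension with unsqueeze()/squeeze()
seq = x.unsqueeze(1)             # [10, 1, 4]
last = seq.squeeze(1)[-1:, :]    # [1, 4]

# after: the same reshapes expressed with view()
seq = x.view(x.size(0), 1, -1)            # [10, 1, 4]
last = seq.view(seq.size(0), -1)[-1:, :]  # [1, 4]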


Hi, I have been having similar issues/errors (the in-place error) when I update the generator twice in a GAN framework.
Kindly help.

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(1, 1)

out = model(x)
out.backward(retain_graph=True)
optimizer.step() # works

# gradients would be wrong, out is not calculated by the current parameters

out.backward(retain_graph=False) # error
optimizer.step()

However, unlike this snippet of code, I do pass the value to the model the second time as well before calling backward() and step().

Could you post a minimal, executable code snippet reproducing the error you are seeing, please?

netG.zero_grad()
label.fill_(real_label)  # fake labels are real for generator cost

# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1).unsqueeze(1)

# Calculate G's loss based on this output
errG = criterion(output, label)

# Calculate gradients for G
errG.backward(retain_graph=True)

# Update G
optimizerG.step()
optimizerG.zero_grad()

# Perform another update on the generator
netG.zero_grad()
label.fill_(real_label)
output = netD(fake)
errG = criterion(output, label)
errG.backward(retain_graph=False)

# Update G
optimizerG.step()
optimizerG.zero_grad()
D_G_z2 = output.mean()

The in-place error arises when I perform the second update on the generator. Kindly help.

This seems to be exactly the same error which I’m describing in my previous post.
Your code snippet is neither complete nor executable so I have to guess a bit.
Assuming fake is the output of the generator, which is still attached to the computation graph:

  • you are calling errG.backward(retain_graph=True), which will compute the gradients w.r.t. the parameters of netG
  • optimizerG.step() updates the parameters of netG, the forward activations are now stale
  • you are calculating output using the same fake tensor and are thus not recreating the computation graph as claimed
  • errG.backward() will try to compute the gradients w.r.t. the parameters in netG and will fail since the activations are stale.

Yes, @ptrblck, it is precisely the same error you described in your earlier post. However, I have not been able to update the generator twice without causing the in-place error. What could be the solution? How do I prevent the forward activations from becoming stale? Is it possible to update the generator twice while only updating the discriminator once?

You would need to recompute the forward pass using the generator and its already updated parameters to create the new forward activations. Alternatively, you could delay the optimizerG.step() call until both backward passes have calculated and accumulated the gradients using the same (initial) parameter set.
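In code, the two options would look roughly like this. I don't have your full script, so the small modules and tensors below are just stand-ins for your netG, netD, criterion, noise, and label; treat it as a sketch rather than a drop-in fix:

import torch
import torch.nn as nn

# stand-ins so the sketch runs on its own
netG = nn.Linear(10, 10)
netD = nn.Linear(10, 1)
criterion = nn.BCEWithLogitsLoss()
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4)
noise = torch.randn(4, 10)
label = torch.ones(4, 1)

# Option 1: step, then recreate the forward activations with the updated parameters
fake = netG(noise)
criterion(netD(fake), label).backward()
optimizerG.step()
optimizerG.zero_grad()

fake = netG(noise)                        # fresh forward pass -> no stale activations
criterion(netD(fake), label).backward()
optimizerG.step()
optimizerG.zero_grad()

# Option 2: run both backward passes on the same (initial) parameters, then step once
fake = netG(noise)
criterion(netD(fake), label).backward(retain_graph=True)
criterion(netD(fake), label).backward()   # gradients accumulate in netG's .grad
optimizerG.step()
optimizerG.zero_grad()
# (netD's gradients are ignored here for brevity)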


Thank you so much @ptrblck