Inplace operation error in CQL code

Hi!
Although I’ve read many posts on the “inplace operation” error, I still haven’t been able to fix my code.
It was working with PyTorch 1.2, but it no longer works with Python 3.8.6 and PyTorch 1.7.

Here is the error message (with anomaly detection enabled, it first prints the traceback of the forward call that caused the error, then the actual error traceback):

"/lustre/scratch/scratch/ucemea0/drl_myriad/cql.py", line 148, in update
    self.q_net2.forward(states, new_actions)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/models.py", line 52, in forw\
ard
    x = self.linear3(x)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/nn/m\
odules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/nn/m\
odules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/nn/f\
unctional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t()) (Triggered internally at  /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp\
:104.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "/home/ucemea0/Scratch/drl_myriad/runTest0.py", line 12, in <module>
    runTestCql(0)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/cql_test.py", line 73, in ru\
nTestCql
    episode_rewards = mini_batch_train(env, agent, n_episodes, n_steps,batch_si\
ze)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/utils.py", line 21, in mini_\batch_train
    agent.update(batch_size)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/cql.py", line 222, in update
    policy_loss.backward(retain_graph=False)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/tens\
or.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/auto\
grad/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been mod\ ified by an inplace operation: [torch.FloatTensor [256, 1]], which is output 0 \
of TBackward, is at version 11; expected version 10 instead. Hint: the backtrac\
e further above shows the operation that failed to compute its gradient. The va\
riable in question was changed in there or anywhere later. Good luck!

Here is the main CQL code I have:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 24 11:30:42 2020

@author: enricoanderlini
"""

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

from models import SoftQNetwork, PolicyNetwork
from replay_buffers import BasicBuffer


class CQLAgent:
    # Class initialiser:
    def __init__(self, env, discount_rate=0.99, reward_scale=1.0, 
                 soft_target_tau=1e-2, policy_eval_start=10,
                 q_lr=1e-3, policy_lr=1e-3, buffer_maxlen=1000000, temp=1.0,
                 min_q_weight=1.0, max_q_backup=False, 
                 num_random=10, with_lagrange=False, lagrange_thresh=0.0):
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Debugging:                                                            
        torch.autograd.set_detect_anomaly(True)
        
        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        #[env.action_space_low, env.action_space_high]      
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        # Hyperparameters:
        self.gamma = discount_rate
        self.tau = soft_target_tau
        self.reward_scale = reward_scale
        self.policy_eval_start = policy_eval_start
        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self._num_policy_steps = 1
        
        # Initialise the networks - SAC:
        self.q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)

        # Copy the parameters to the target parameters - SAC:
        for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
            target_param.data.copy_(param)

        for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
            target_param.data.copy_(param)

        # Initialize the optimisers - SAC:
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        # Entropy temperature - SAC:
        self.target_entropy = -torch.prod(torch.Tensor(self.env.action_space.shape).to(self.device)).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha = self.log_alpha.exp()
        self.alpha_optim = optim.Adam([self.log_alpha], lr=policy_lr) #a_lr??

        # Initialise the replay buffer:
        self.replay_buffer = BasicBuffer(buffer_maxlen)
        
        # Back-up:
        self.with_lagrange = with_lagrange
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh
            self.log_alpha_prime = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_prime_optimizer = optim.Adam([self.log_alpha_prime], 
                                                    lr=q_lr) #a_lr??
        self.max_q_backup = max_q_backup
        self.num_random = num_random
        
        # CQL:
        self.temperature = temp
        self.min_q_weight = min_q_weight
    

    def get_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        mean, log_std = self.policy_net.forward(state)
        std = log_std.exp()
        
        normal = Normal(mean, std)
        z = normal.sample()
        action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()
        
        return self.rescale_action(action)
    
    def rescale_action(self, action):
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0
    
    def _get_tensor_values(self, obs, actions, network=None):
        action_shape = actions.shape[0]
        obs_shape = obs.shape[0]
        num_repeat = int(action_shape / obs_shape)
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(obs.shape[0] * num_repeat, obs.shape[1])
        preds = network(obs_temp, actions)
        preds = preds.view(obs.shape[0], num_repeat, 1)
        return preds

    def _get_policy_actions(self, obs, num_actions, network=None):
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions, new_obs_log_pi = network.sample(obs_temp)
        return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
    
    # Training step update:
    def update(self, batch_size):
        # Update the current epoch:
        self._current_epoch += 1
        
        # Sample a batch from the replay buffer:
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)
        
        # Obtain the new action by sampling from the policy network:
        new_actions, log_pi = self.policy_net.sample(states)
        next_actions, next_log_pi = self.policy_net.sample(next_states)
        # Update alpha and its loss:
        alpha_loss = (self.log_alpha * (-log_pi - self.target_entropy).detach()).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp()
        # Get the policy loss:
        min_q = torch.min(
                self.q_net1.forward(states, new_actions),
                self.q_net2.forward(states, new_actions)
            )
        policy_loss = (self.alpha * log_pi - min_q).mean()
        
        # Attempt behavioural cloning for the first few epochs:
        if self._current_epoch < self.policy_eval_start:
            policy_log_prob = self.policy_net.log_prob(states, actions)
            policy_loss = (self.alpha * log_pi - policy_log_prob).mean()
        
        # Obtain the current and new Q-values:     
        next_q1 = self.target_q_net1(next_states, next_actions)
        next_q2 = self.target_q_net2(next_states, next_actions)
        next_q_target = torch.min(next_q1, next_q2) - self.alpha * next_log_pi
        expected_q = self.reward_scale * rewards + (1 - dones) * self.gamma * next_q_target
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)      
        # Calculate the Q-function loss:
        q1_loss = F.mse_loss(curr_q1, expected_q.detach())
        q2_loss = F.mse_loss(curr_q2, expected_q.detach())
        
        # Add the CQL steps:
        random_actions_tensor = torch.FloatTensor(next_q1.shape[0] * self.num_random, actions.shape[-1]).uniform_(-1, 1).to(self.device)
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(states, num_actions=self.num_random, network=self.policy_net)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_states, num_actions=self.num_random, network=self.policy_net)
        q1_rand = self._get_tensor_values(states, random_actions_tensor, network=self.q_net1)
        q2_rand = self._get_tensor_values(states, random_actions_tensor, network=self.q_net2)
        q1_curr_actions = self._get_tensor_values(states, curr_actions_tensor, network=self.q_net1)
        q2_curr_actions = self._get_tensor_values(states, curr_actions_tensor, network=self.q_net2)
        q1_next_actions = self._get_tensor_values(states, new_curr_actions_tensor, network=self.q_net1)
        q2_next_actions = self._get_tensor_values(states, new_curr_actions_tensor, network=self.q_net2)

        q1_pred = self.q_net1(states, actions)
        q2_pred = self.q_net2(states, actions)
        cat_q1 = torch.cat(
            [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
        )
        cat_q2 = torch.cat(
            [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
        )
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)
        
        min_q1_loss = torch.logsumexp(cat_q1 / self.temperature, dim=1,).mean() * self.min_q_weight * self.temperature
        min_q2_loss = torch.logsumexp(cat_q2 / self.temperature, dim=1,).mean() * self.min_q_weight * self.temperature
                    
        # Subtract the log likelihood of the data:
        min_q1_loss = min_q1_loss - q1_pred.mean() * self.min_q_weight
        min_q2_loss = min_q2_loss - q2_pred.mean() * self.min_q_weight
        
        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_q1_loss = alpha_prime * (min_q1_loss - self.target_action_gap)
            min_q2_loss = alpha_prime * (min_q2_loss - self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_q1_loss - min_q2_loss)*0.5 
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        q1_loss = q1_loss + min_q1_loss
        q2_loss = q2_loss + min_q2_loss
        
        # Update the Q networks:
        self._num_q_update_steps += 1
        self.q1_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.q1_optimizer.step()
        self.q2_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.q2_optimizer.step()

        # Update the policy network:
        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()

        # Soft update with the target networks:
        for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
            target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)
        for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
            target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)
        
        # Update the number of training steps:
        #self._n_train_steps_total += 1

And here are the models:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from distributions import TanhNormal


MEAN_MIN = -9.0
MEAN_MAX = 9.0


def atanh(x):
    one_plus_x = (1 + x).clamp(min=1e-6)
    one_minus_x = (1 - x).clamp(min=1e-6)
    return 0.5 * torch.log(one_plus_x / one_minus_x)


class ValueNetwork(nn.Module):

    def __init__(self, input_dim, output_dim, init_w=3e-3):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, output_dim)

        self.fc3.weight.data.uniform_(-init_w, init_w)
        self.fc3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


class SoftQNetwork(nn.Module):
    
    def __init__(self, num_inputs, num_actions, hidden_size=256, init_w=3e-3):
        super(SoftQNetwork, self).__init__()
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x


class PolicyNetwork(nn.Module):
    
    def __init__(self, num_inputs, num_actions, hidden_size=256, init_w=3e-3,
                 log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))

        mean    = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)

        return mean, log_std

    def sample(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        std = log_std.exp()

        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)

        log_pi = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_pi = log_pi.sum(1, keepdim=True)

        return action, log_pi
    
    def log_prob(self, obs, actions):
        raw_actions = atanh(actions)
        
        x = F.relu(self.linear1(obs))
        x = F.relu(self.linear2(x))
        
        mean    = self.mean_linear(x)
        mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
        
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)

        tanh_normal = TanhNormal(mean, std)
        log_prob = tanh_normal.log_prob(value=actions, pre_tanh_value=raw_actions)
        return log_prob.sum(-1)

I don't fully understand the error message. Is it a version problem, or am I actually doing an incorrect inplace operation somewhere?
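
For reference, here is a small toy example (not from my code, just my own sketch of the pattern I keep seeing mentioned in other posts about this error): an optimizer.step() between two backward passes that share a retained graph. As far as I understand, something like this raises the same RuntimeError on PyTorch 1.7, because step() updates the weights in place and bumps their version counters, while the retained graph still expects the old versions. I am wondering whether my q1/q2 optimizer steps before policy_loss.backward() are the analogous problem here, but I am not sure.

import torch
import torch.nn as nn

# Two linear layers, so the second layer's weight is needed
# when back-propagating into the first layer.
net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 1))
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randn(16, 4)
out = net(x)
loss_a = out.mean()
loss_b = (out ** 2).mean()

loss_a.backward(retain_graph=True)
opt.step()           # in-place parameter update bumps the weights' version counters
loss_b.backward()    # backward through the retained graph checks the saved tensors'
                     # versions and fails: "modified by an inplace operation"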

Many thanks!