Hi!
Although I’ve read many posts on the “inplace operation” error, I still haven’t been able to fix my code.
It worked in Torch v1.2, but it no longer works with Python 3.8.6 and Torch v1.7.
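From those posts, my (possibly wrong) understanding of this error class is that it is raised whenever a tensor that autograd has saved for the backward pass gets modified in place before backward() runs, along these lines:

import torch

a = torch.ones(3, requires_grad=True)
b = a.exp()          # autograd saves the output of exp() for its backward pass
b.add_(1.0)          # the in-place add bumps b's version counter
b.sum().backward()   # RuntimeError: ... modified by an inplace operation

But I cannot spot where my own code does anything like that.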
Here is the error message (after it refers to the culprit file):

  File "/lustre/scratch/scratch/ucemea0/drl_myriad/cql.py", line 148, in update
    self.q_net2.forward(states, new_actions)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/models.py", line 52, in forward
    x = self.linear3(x)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
 (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "/home/ucemea0/Scratch/drl_myriad/runTest0.py", line 12, in <module>
    runTestCql(0)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/cql_test.py", line 73, in runTestCql
    episode_rewards = mini_batch_train(env, agent, n_episodes, n_steps, batch_size)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/utils.py", line 21, in mini_batch_train
    agent.update(batch_size)
  File "/lustre/scratch/scratch/ucemea0/drl_myriad/cql.py", line 222, in update
    policy_loss.backward(retain_graph=False)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/lustre/home/ucemea0/venv_enrico/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256, 1]], which is output 0 of TBackward, is at version 11; expected version 10 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Here is the main CQL code I have:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 24 11:30:42 2020
@author: enricoanderlini
"""
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from models import SoftQNetwork, PolicyNetwork
from replay_buffers import BasicBuffer
class CQLAgent:
    # Class initialiser:
    def __init__(self, env, discount_rate=0.99, reward_scale=1.0,
                 soft_target_tau=1e-2, policy_eval_start=10,
                 q_lr=1e-3, policy_lr=1e-3, buffer_maxlen=1000000, temp=1.0,
                 min_q_weight=1.0, max_q_backup=False,
                 num_random=10, with_lagrange=False, lagrange_thresh=0.0):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Debugging:
        torch.autograd.set_detect_anomaly(True)

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        # Hyperparameters:
        self.gamma = discount_rate
        self.tau = soft_target_tau
        self.reward_scale = reward_scale
        self.policy_eval_start = policy_eval_start
        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self._num_policy_steps = 1

        # Initialise the networks - SAC:
        self.q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)

        # Copy the parameters to the target parameters - SAC:
        for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
            target_param.data.copy_(param)
        for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
            target_param.data.copy_(param)

        # Initialise the optimisers - SAC:
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        # Entropy temperature - SAC:
        self.target_entropy = -torch.prod(torch.Tensor(self.env.action_space.shape).to(self.device)).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=policy_lr)  # a_lr??
        self.alpha = self.log_alpha.exp()

        # Initialise the replay buffer:
        self.replay_buffer = BasicBuffer(buffer_maxlen)

        # Back-up:
        self.with_lagrange = with_lagrange
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh
            self.log_alpha_prime = torch.zeros(1, requires_grad=True)
            self.alpha_prime_optimizer = optim.Adam([self.log_alpha_prime],
                                                    lr=q_lr)  # a_lr??
        self.max_q_backup = max_q_backup
        self.num_random = num_random

        # CQL:
        self.temperature = temp
        self.min_q_weight = min_q_weight
    def get_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        mean, log_std = self.policy_net.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.sample()
        action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()
        return self.rescale_action(action)

    def rescale_action(self, action):
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 + \
               (self.action_range[1] + self.action_range[0]) / 2.0

    def _get_tensor_values(self, obs, actions, network=None):
        action_shape = actions.shape[0]
        obs_shape = obs.shape[0]
        num_repeat = int(action_shape / obs_shape)
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(obs.shape[0] * num_repeat, obs.shape[1])
        preds = network(obs_temp, actions)
        preds = preds.view(obs.shape[0], num_repeat, 1)
        return preds

    def _get_policy_actions(self, obs, num_actions, network=None):
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions, new_obs_log_pi = network.sample(obs_temp)
        return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
    # Training step update:
    def update(self, batch_size):
        # Update the current epoch:
        self._current_epoch += 1

        # Sample a batch from the replay buffer:
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        # Obtain the new action by sampling from the policy network:
        new_actions, log_pi = self.policy_net.sample(states)
        next_actions, next_log_pi = self.policy_net.sample(next_states)

        # Update alpha and its loss:
        alpha_loss = (self.log_alpha * (-log_pi - self.target_entropy).detach()).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp()

        # Get the policy loss:
        min_q = torch.min(
            self.q_net1.forward(states, new_actions),
            self.q_net2.forward(states, new_actions)
        )
        policy_loss = (self.alpha * log_pi - min_q).mean()

        # Attempt behavioural cloning for the first few epochs:
        if self._current_epoch < self.policy_eval_start:
            policy_log_prob = self.policy_net.log_prob(states, actions)
            policy_loss = (self.alpha * log_pi - policy_log_prob).mean()

        # Obtain the current and new Q-values:
        next_q1 = self.target_q_net1(next_states, next_actions)
        next_q2 = self.target_q_net2(next_states, next_actions)
        next_q_target = torch.min(next_q1, next_q2) - self.alpha * next_log_pi
        expected_q = self.reward_scale * rewards + (1 - dones) * self.gamma * next_q_target
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)

        # Calculate the Q-function loss:
        q1_loss = F.mse_loss(curr_q1, expected_q.detach())
        q2_loss = F.mse_loss(curr_q2, expected_q.detach())

        # Add the CQL steps:
        random_actions_tensor = torch.FloatTensor(next_q1.shape[0] * self.num_random, actions.shape[-1]).uniform_(-1, 1)
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(states, num_actions=self.num_random, network=self.policy_net)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_states, num_actions=self.num_random, network=self.policy_net)
        q1_rand = self._get_tensor_values(states, random_actions_tensor, network=self.q_net1)
        q2_rand = self._get_tensor_values(states, random_actions_tensor, network=self.q_net2)
        q1_curr_actions = self._get_tensor_values(states, curr_actions_tensor, network=self.q_net1)
        q2_curr_actions = self._get_tensor_values(states, curr_actions_tensor, network=self.q_net2)
        q1_next_actions = self._get_tensor_values(states, new_curr_actions_tensor, network=self.q_net1)
        q2_next_actions = self._get_tensor_values(states, new_curr_actions_tensor, network=self.q_net2)
        q1_pred = self.q_net1(states, actions)
        q2_pred = self.q_net2(states, actions)
        cat_q1 = torch.cat(
            [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
        )
        cat_q2 = torch.cat(
            [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
        )
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)
        min_q1_loss = torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() * self.min_q_weight * self.temperature
        min_q2_loss = torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() * self.min_q_weight * self.temperature

        # Subtract the log likelihood of the data:
        min_q1_loss = min_q1_loss - q1_pred.mean() * self.min_q_weight
        min_q2_loss = min_q2_loss - q2_pred.mean() * self.min_q_weight

        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_q1_loss = alpha_prime * (min_q1_loss - self.target_action_gap)
            min_q2_loss = alpha_prime * (min_q2_loss - self.target_action_gap)
            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_q1_loss - min_q2_loss) * 0.5
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        q1_loss = q1_loss + min_q1_loss
        q2_loss = q2_loss + min_q2_loss

        # Update the Q networks:
        self._num_q_update_steps += 1
        self.q1_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.q1_optimizer.step()
        self.q2_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.q2_optimizer.step()

        # Update the policy network:
        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()

        # Soft update with the target networks:
        for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
            target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)
        for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
            target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)

        # Update the number of training steps:
        #self._n_train_steps_total += 1
And here are the models:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from distributions import TanhNormal

MEAN_MIN = -9.0
MEAN_MAX = 9.0

def atanh(x):
    one_plus_x = (1 + x).clamp(min=1e-6)
    one_minus_x = (1 - x).clamp(min=1e-6)
    return 0.5 * torch.log(one_plus_x / one_minus_x)
class ValueNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, init_w=3e-3):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, output_dim)
        self.fc3.weight.data.uniform_(-init_w, init_w)
        self.fc3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size=256, init_w=3e-3):
        super(SoftQNetwork, self).__init__()
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size=256, init_w=3e-3,
                 log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)
        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)
        log_pi = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_pi = log_pi.sum(1, keepdim=True)
        return action, log_pi

    def log_prob(self, obs, actions):
        raw_actions = atanh(actions)
        x = F.relu(self.linear1(obs))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        tanh_normal = TanhNormal(mean, std)
        log_prob = tanh_normal.log_prob(value=actions, pre_tanh_value=raw_actions)
        return log_prob.sum(-1)
I don't exactly understand the error message. Is it a version problem? Or am I actually doing a wrong in-place operation somewhere?
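One thing I do notice: the [torch.FloatTensor [256, 1]] in the message looks like the transposed weight of linear3 in SoftQNetwork (256 hidden units -> 1 output), which torch.addmm saves for the backward pass. Could the culprit be the optimiser steps that run between the backward() calls in update()? As far as I understand, optimizer.step() updates the network weights in place, so a toy pattern like this (not my actual code, just my reading of the error) would raise the same complaint:

import torch

lin = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)
x = torch.randn(4, 2, requires_grad=True)  # the input carries gradients, so lin.weight is saved for backward
out = lin(x)
loss1 = out.sum()
loss2 = (out ** 2).sum()

loss1.backward(retain_graph=True)
opt.step()          # in-place update of lin.weight bumps its version counter
loss2.backward()    # RuntimeError: ... modified by an inplace operation

If that is the cause, though, I am not sure how to restructure the backward()/step() pairs in update().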
Many thanks!