I’ve been stuck on a problem for a while, and the error I get is as follows:
UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
File "IC3Net/main_copy.py", line 578, in <module>
run(args.num_epochs)
File "IC3Net/main_copy.py", line 452, in run
s, cpu_mem_peak, gpu_mem_peak = trainer.train_batch(ep)
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 611, in train_batch
batch, stat = self.run_batch(epoch)
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 560, in run_batch
episode, episode_stat = self.get_episode(epoch)
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 188, in get_episode
action_out, value, prev_hid = self.policy.batch_select_action_universal(x, self.stats['num_episodes'])
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/hetgat/policy.py", line 343, in batch_select_action_universal
results, c_v, a_v, prev_hid = self.model(x, g)
File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/hetgat/uavnet.py", line 349, in forward
A_critic_value = self.A_critic_head(self.relu(h2['state'][1:, :]))
File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
return F.linear(input, self.weight, self.bias)
File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/functional.py", line 1848, in linear
return torch._C._nn.linear(input, weight, bias)
(Triggered internally at /opt/conda/conda-bld/pytorch_1639180588308/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "IC3Net/main_copy.py", line 578, in <module>
run(args.num_epochs)
File "IC3Net/main_copy.py", line 452, in run
s, cpu_mem_peak, gpu_mem_peak = trainer.train_batch(ep)
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 621, in train_batch
s = self.compute_grad(batch)
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 520, in compute_grad
loss = self.policy.batch_finish_per_class(
File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/hetgat/policy.py", line 701, in batch_finish_per_class
total_loss.backward(retain_graph=False)
File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [8, 1]], which is output 0 of AsStridedBackward0, is at version 11; expected version 10 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I’ve tried to trace the problem to its source, but to no avail. Can anyone please give me some insight into what might be going on?
Here is a snippet of interest. The codebase is huge, so I can post more snippets from other files if you want to take a look at those. I’ve been stuck on this for a while, so any kind of insight will be greatly appreciated!
# uavnet.py:forward()
def forward(self, x, g):
x, extras = self.forward_state_encoder(x)
# hidden_state = hidden_state.to(self.device)
# cell_state = cell_state.to(self.device)
hidden_state_per_stat = extras['P_s'][0].to(self.device)
hidden_state_act_stat = extras['A_s'][0].to(self.device)
hidden_state_per_obs = extras['P_o'][0].to(self.device)
cell_state_per_stat = extras['P_s'][1].to(self.device)
cell_state_act_stat = extras['A_s'][1].to(self.device)
cell_state_per_obs = extras['P_o'][1].to(self.device)
if self.action_vision >= 0 or self.small_action_observation_space:
# action should be full obs * num_squares, action is just one observation
x_per_stat, x_act_stat = self.remove_excess_action_features_from_all_uneven_state(x)
if self.action_vision >= 0:
x_per_obs, x_act_obs = self.get_obs_features_uneven_obs(x)
x_per_obs = x_per_obs.to(self.device)
x_act_obs = x_act_obs.to(self.device)
else:
x_per_obs = self.get_obs_features(x).to(self.device)
elif not self.tensor_obs:
x_per_stat, x_act_stat = self.remove_excess_action_features_from_all(x)
x_per_obs = self.get_obs_features(x).to(self.device)
else:
x_per_stat, x_act_stat, x_per_obs = self.get_states_obs_from_tensor(x)
feat_dict = {}
'''
Data preprocessing - P
'''
state_per_stat = torch.Tensor(x_per_stat).to(self.device)
if not self.tensor_obs:
state_per_stat = self.relu(self.prepro_state(state_per_stat)) # output is 2,225
else:
state_per_stat = self.prepro_state(state_per_stat)
state_per_stat = self.avgpool(state_per_stat)
# Np x dim x 1 x 1
state_per_stat = torch.flatten(state_per_stat, 1)
state_per_stat = self.prepro_state_2(state_per_stat)
hidden_state_per_stat, cell_state_per_stat = self.f_module_stat(state_per_stat,
(hidden_state_per_stat, cell_state_per_stat))
x_per_obs = x_per_obs.to(self.device)
if not self.tensor_obs:
x_per_obs = self.relu(self.prepro_obs(x_per_obs))
else:
# check if following line works for vision = 1
x_per_obs = x_per_obs.squeeze()
x_per_obs = self.prepro_obs(x_per_obs)
if self.vision == 1:
x_per_obs = self.avgpool(x_per_obs)
# Np x dim x 1 x 1
x_per_obs = torch.flatten(x_per_obs, 1)
x_per_obs = self.prepro_obs_2(x_per_obs)
hidden_state_per_obs, cell_state_per_obs = self.f_module_obs(x_per_obs.reshape((x_per_obs.shape[1], x_per_obs.shape[2])),
(hidden_state_per_obs, cell_state_per_obs))
feat_dict['P'] = torch.cat([hidden_state_per_stat, hidden_state_per_obs], dim=1)
state_act = torch.Tensor(x_act_stat).to(self.device)
if not self.tensor_obs:
if self.action_vision >= 0 or self.small_action_observation_space:
state_act = self.relu(self.prepro_state_for_action(state_act))
else:
state_act = self.relu(self.prepro_state(state_act))
else:
state_act = self.prepro_stat(state_act)
state_act = self.avgpool(state_act)
# Np x dim x 1 x 1
state_act = torch.flatten(state_act, 1)
state_act = self.prepro_state_2(state_act)
hidden_state_act_stat, cell_state_act_stat = self.f_module_stat(state_act,
(hidden_state_act_stat, cell_state_act_stat))
if self.action_vision >= 0:
x_act_obs = x_act_obs.to(self.device)
if not self.tensor_obs:
x_act_obs = self.relu(self.prepro_obs_for_action(x_act_obs)).reshape(x_act_obs.shape[1],-1)
else:
raise NotImplementedError
# hidden_state_act_obs, cell_state_act_obs = self.f_module_obs(x_act_obs.squeeze(),
# (hidden_state_act_obs, cell_state_act_obs))
feat_dict['A'] = torch.cat([hidden_state_act_stat, x_act_obs], dim=1)
else:
feat_dict['A'] = hidden_state_act_stat
# complete_state = torch.cat([state_per, state_act])
# hidden_state, cell_state = self.f_module(complete_state.squeeze(), (hidden_state, cell_state))
# feat_dict['P'] = hidden_state[0:2]
# feat_dict['A'] = hidden_state[2].reshape(1, 32)
'''
# Np x 1 x H x W
x = self.features(raw_f_d['P_s'])
# Np x dim x H/4 x W/4
x = self.avgpool(x)
# Np x dim x 1 x 1
p_s = torch.flatten(x, 1) # this works well when Np == 1
# Np x dim
'''
'''
Data preprocessing - A
# '''
# status_A = self.prepro(raw_f_d['A'])
# feat_dict['A'] = status_A
# add state node
if self.with_two_state:
feat_dict['state'] = torch.tensor([
[self.num_P, self.num_A, self.world_dim, self.total_state_action_in_batch],
[self.num_P, self.num_A, self.world_dim, self.total_state_action_in_batch]
]).to(self.device)
else:
feat_dict['state'] = torch.tensor([self.num_P, self.num_A, self.world_dim, self.total_state_action_in_batch]).to(self.device)
h1 = self.layer1(g, feat_dict)
h2 = self.layer2(g, h1)
# get critic prediction, 1x1
if self.per_class_critic:
if self.with_two_state:
# h2['state'] is 2 x dim, first 1 x dim is p state, second is a state
P_critic_value = self.P_critic_head(self.relu(h2['state'][:1, :]))
A_critic_value = self.A_critic_head(self.relu(h2['state'][1:, :]))
if self.use_tanh:
P_critic_value = self.tanh(P_critic_value)
A_critic_value = self.tanh(A_critic_value)
else:
P_critic_value = self.P_critic_head(self.relu(h2['state']))
A_critic_value = self.A_critic_head(self.relu(h2['state']))
if self.use_tanh:
P_critic_value = self.tanh(P_critic_value)
A_critic_value = self.tanh(A_critic_value)
h = {}
h['P_s'] = hidden_state_per_stat, cell_state_per_stat
h['P_o'] = hidden_state_per_obs, cell_state_per_obs
h['A_s'] = hidden_state_act_stat, cell_state_act_stat
return h2, P_critic_value, A_critic_value, h
elif self.per_agent_critic:
# P agents
# num_P = len(h2['P'])
num_P = self.num_P
tmp_list = []
for i in range(num_P):
tmp_list.append(h2['state'])
hp_emb = h2['P'] # num_P x out['P']
hs_emb = torch.cat(tmp_list) # num_P x out['state']
P_critic_input = torch.cat((hp_emb, hs_emb), dim=1) # num_P x out['P'+'state']
P_critic_value = self.P_critic_head(self.relu(P_critic_input)) # num_P x 1
# A agents
num_A = len(h2['A'])
tmp_list = []
for i in range(num_A):
tmp_list.append(h2['state'])
ha_emb = h2['A'] # num_A x out['P']
hs_emb = torch.cat(tmp_list) # num_A x out['state']
A_critic_input = torch.cat((ha_emb, hs_emb), dim=1) # num_A x out['A'+'state']
A_critic_value = self.A_critic_head(self.relu(A_critic_input)) # num_A x 1
if self.use_tanh:
P_critic_value = self.tanh(P_critic_value)
A_critic_value = self.tanh(A_critic_value)
h = {}
h['P_s'] = hidden_state_per_stat, cell_state_per_stat
h['P_o'] = hidden_state_per_obs, cell_state_per_obs
h['A_s'] = hidden_state_act_stat, cell_state_act_stat
return h2, P_critic_value, A_critic_value, h
else:
critic_value = self.critic_head(self.relu(h2['state']))
if self.use_tanh:
critic_value = self.relu(critic_value)
h = {}
h['P_s'] = hidden_state_per_stat, cell_state_per_stat
h['P_o'] = hidden_state_per_obs, cell_state_per_obs
h['A_s'] = hidden_state_act_stat, cell_state_act_stat
return h2, critic_value, h