RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [8, 1]], which is output 0 of AsStridedBackward0, is at version 11; expected version 10 instead

I’ve been stuck on a problem for a while, and the error I get is as follows:

UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "IC3Net/main_copy.py", line 578, in <module>
    run(args.num_epochs)
  File "IC3Net/main_copy.py", line 452, in run
    s, cpu_mem_peak, gpu_mem_peak = trainer.train_batch(ep)
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 611, in train_batch
    batch, stat = self.run_batch(epoch)
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 560, in run_batch
    episode, episode_stat = self.get_episode(epoch)
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 188, in get_episode
    action_out, value, prev_hid = self.policy.batch_select_action_universal(x, self.stats['num_episodes'])
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/hetgat/policy.py", line 343, in batch_select_action_universal
    results, c_v, a_v, prev_hid = self.model(x, g)
  File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/hetgat/uavnet.py", line 349, in forward
    A_critic_value = self.A_critic_head(self.relu(h2['state'][1:, :]))
  File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1639180588308/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "IC3Net/main_copy.py", line 578, in <module>
    run(args.num_epochs)
  File "IC3Net/main_copy.py", line 452, in run
    s, cpu_mem_peak, gpu_mem_peak = trainer.train_batch(ep)
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 621, in train_batch
    s = self.compute_grad(batch)
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/trainer.py", line 520, in compute_grad
    loss = self.policy.batch_finish_per_class(
  File "/home/kadhir/research/HetGAT_MARL_Communication/test/IC3Net/hetgat/policy.py", line 701, in batch_finish_per_class
    total_loss.backward(retain_graph=False)
  File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/kadhir/anaconda3/envs/temp_hetgat/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [8, 1]], which is output 0 of AsStridedBackward0, is at version 11; expected version 10 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I’ve tried to trace the problem to the source but to no avail. Can anyone please give me some insight on what might be going on?

Here is the snippet of interest, the forward() method from uavnet.py. The codebase is large, so I can post snippets from other files if that would help. Any insight would be greatly appreciated!

# uavnet.py:forward()
def forward(self, x, g):
        x, extras = self.forward_state_encoder(x)

        # hidden_state = hidden_state.to(self.device)
        # cell_state = cell_state.to(self.device)
        hidden_state_per_stat = extras['P_s'][0].to(self.device)
        hidden_state_act_stat = extras['A_s'][0].to(self.device)
        hidden_state_per_obs = extras['P_o'][0].to(self.device)
        cell_state_per_stat = extras['P_s'][1].to(self.device)
        cell_state_act_stat = extras['A_s'][1].to(self.device)
        cell_state_per_obs = extras['P_o'][1].to(self.device)

        if self.action_vision >= 0 or self.small_action_observation_space:
            # action should be full obs * num_squares, action is just one observation
            x_per_stat, x_act_stat = self.remove_excess_action_features_from_all_uneven_state(x)
            if self.action_vision >= 0:
                x_per_obs, x_act_obs = self.get_obs_features_uneven_obs(x)
                x_per_obs = x_per_obs.to(self.device)
                x_act_obs = x_act_obs.to(self.device)
            else:
                x_per_obs = self.get_obs_features(x).to(self.device)
        elif not self.tensor_obs:
            x_per_stat, x_act_stat = self.remove_excess_action_features_from_all(x)
            x_per_obs = self.get_obs_features(x).to(self.device)
        else:
            x_per_stat, x_act_stat, x_per_obs = self.get_states_obs_from_tensor(x)
        feat_dict = {}

        '''
        Data preprocessing - P
        '''
        state_per_stat = torch.Tensor(x_per_stat).to(self.device)

        if not self.tensor_obs:
            state_per_stat = self.relu(self.prepro_state(state_per_stat)) # output is 2,225
        else:
            state_per_stat = self.prepro_state(state_per_stat)
            state_per_stat = self.avgpool(state_per_stat)
            # Np x dim x 1 x 1
            state_per_stat = torch.flatten(state_per_stat, 1)
            state_per_stat = self.prepro_state_2(state_per_stat)

        hidden_state_per_stat, cell_state_per_stat = self.f_module_stat(state_per_stat,
                                                                        (hidden_state_per_stat, cell_state_per_stat))

        x_per_obs = x_per_obs.to(self.device)
        if not self.tensor_obs:
            x_per_obs = self.relu(self.prepro_obs(x_per_obs))
        else:
            # check if following line works for vision = 1
            x_per_obs = x_per_obs.squeeze()
            x_per_obs = self.prepro_obs(x_per_obs)
            if self.vision == 1:
                x_per_obs = self.avgpool(x_per_obs)
                # Np x dim x 1 x 1
                x_per_obs = torch.flatten(x_per_obs, 1)
            x_per_obs = self.prepro_obs_2(x_per_obs)

        hidden_state_per_obs, cell_state_per_obs = self.f_module_obs(x_per_obs.reshape((x_per_obs.shape[1], x_per_obs.shape[2])),
                                                                     (hidden_state_per_obs, cell_state_per_obs))
        feat_dict['P'] = torch.cat([hidden_state_per_stat, hidden_state_per_obs], dim=1)

        state_act = torch.Tensor(x_act_stat).to(self.device)

        if not self.tensor_obs:
            if self.action_vision >= 0 or self.small_action_observation_space:
                state_act = self.relu(self.prepro_state_for_action(state_act))
            else:
                state_act = self.relu(self.prepro_state(state_act))
        else:
            state_act = self.prepro_state(state_act)
            state_act = self.avgpool(state_act)
            # Np x dim x 1 x 1
            state_act = torch.flatten(state_act, 1)
            state_act = self.prepro_state_2(state_act)

        hidden_state_act_stat, cell_state_act_stat = self.f_module_stat(state_act,
                                                                        (hidden_state_act_stat, cell_state_act_stat))

        if self.action_vision >= 0:
            x_act_obs = x_act_obs.to(self.device)
            if not self.tensor_obs:
                x_act_obs = self.relu(self.prepro_obs_for_action(x_act_obs)).reshape(x_act_obs.shape[1],-1)
            else:
                raise NotImplementedError

            # hidden_state_act_obs, cell_state_act_obs = self.f_module_obs(x_act_obs.squeeze(),
            #                                                              (hidden_state_act_obs, cell_state_act_obs))
            feat_dict['A'] = torch.cat([hidden_state_act_stat, x_act_obs], dim=1)
        else:

            feat_dict['A'] = hidden_state_act_stat

        # complete_state = torch.cat([state_per, state_act])
        # hidden_state, cell_state = self.f_module(complete_state.squeeze(), (hidden_state, cell_state))
        # feat_dict['P'] = hidden_state[0:2]
        # feat_dict['A'] = hidden_state[2].reshape(1, 32)
        '''
        # Np x 1 x H x W
        x = self.features(raw_f_d['P_s'])
        # Np x dim x H/4 x W/4
        x = self.avgpool(x)
        # Np x dim x 1 x 1
        p_s = torch.flatten(x, 1) # this works well when Np == 1
        # Np x dim
        '''
        '''
        Data preprocessing - A
        # '''
        # status_A = self.prepro(raw_f_d['A'])

        # feat_dict['A'] = status_A

        # add state node
        if self.with_two_state:
            feat_dict['state'] = torch.tensor([
                [self.num_P, self.num_A, self.world_dim, self.total_state_action_in_batch],
                [self.num_P, self.num_A, self.world_dim, self.total_state_action_in_batch]
            ]).to(self.device)
        else:
            feat_dict['state'] = torch.tensor([self.num_P, self.num_A, self.world_dim, self.total_state_action_in_batch]).to(self.device)

        h1 = self.layer1(g, feat_dict)
        h2 = self.layer2(g, h1)

        # get critic prediction, 1x1
        if self.per_class_critic:
            if self.with_two_state:
                # h2['state'] is 2 x dim, first 1 x dim is p state, second is a state
                P_critic_value = self.P_critic_head(self.relu(h2['state'][:1, :]))
                A_critic_value = self.A_critic_head(self.relu(h2['state'][1:, :]))
                if self.use_tanh:
                    P_critic_value = self.tanh(P_critic_value)
                    A_critic_value = self.tanh(A_critic_value)
            else:
                P_critic_value = self.P_critic_head(self.relu(h2['state']))
                A_critic_value = self.A_critic_head(self.relu(h2['state']))
                if self.use_tanh:
                    P_critic_value = self.tanh(P_critic_value)
                    A_critic_value = self.tanh(A_critic_value)

            h = {}
            h['P_s'] = hidden_state_per_stat, cell_state_per_stat
            h['P_o'] = hidden_state_per_obs, cell_state_per_obs
            h['A_s'] = hidden_state_act_stat, cell_state_act_stat

            return h2, P_critic_value, A_critic_value, h
        elif self.per_agent_critic:
            # P agents
            # num_P = len(h2['P'])
            num_P = self.num_P
            tmp_list = []
            for i in range(num_P):
                tmp_list.append(h2['state'])

            hp_emb = h2['P']  # num_P x out['P']
            hs_emb = torch.cat(tmp_list)  # num_P x out['state']
            P_critic_input = torch.cat((hp_emb, hs_emb), dim=1)  # num_P x out['P'+'state']
            P_critic_value = self.P_critic_head(self.relu(P_critic_input))  # num_P x 1
            # A agents
            num_A = len(h2['A'])
            tmp_list = []
            for i in range(num_A):
                tmp_list.append(h2['state'])

            ha_emb = h2['A']  # num_A x out['P']
            hs_emb = torch.cat(tmp_list)  # num_A x out['state']
            A_critic_input = torch.cat((ha_emb, hs_emb), dim=1)  # num_A x out['A'+'state']
            A_critic_value = self.A_critic_head(self.relu(A_critic_input))  # num_A x 1

            if self.use_tanh:
                P_critic_value = self.tanh(P_critic_value)
                A_critic_value = self.tanh(A_critic_value)
            h = {}
            h['P_s'] = hidden_state_per_stat, cell_state_per_stat
            h['P_o'] = hidden_state_per_obs, cell_state_per_obs
            h['A_s'] = hidden_state_act_stat, cell_state_act_stat
            return h2, P_critic_value, A_critic_value, h
        else:
            critic_value = self.critic_head(self.relu(h2['state']))
            if self.use_tanh:
                critic_value = self.tanh(critic_value)

            h = {}
            h['P_s'] = hidden_state_per_stat, cell_state_per_stat
            h['P_o'] = hidden_state_per_obs, cell_state_per_obs

            h['A_s'] = hidden_state_act_stat, cell_state_act_stat

            return h2, critic_value, h

This error usually occurs when a tensor that autograd still needs for the backward pass is overwritten in place, for example:

param[0] = 0
param[0] = torch.Tensor(~~~)

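where param is a tensor that autograd needs for the backward pass (for example, one that requires a gradient, or an intermediate result saved during the forward pass).
Check whether your code manually assigns values into such a tensor anywhere between the forward pass and the call to backward().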
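
To make the pattern above concrete, here is a minimal, self-contained sketch that triggers the same class of error on purpose. The tensors and function names are made up for illustration and are not taken from the code in the question. It uses `exp()` because that operation saves its output for the backward pass, and it reads the private `_version` attribute only to show the counter that the error message refers to ("is at version 11; expected version 10").

```python
import torch


def inplace_error_message():
    """Overwrite a saved tensor in place and return the resulting error text."""
    w = torch.ones(3, requires_grad=True)
    y = w.exp()                  # exp() keeps its output y for the backward pass
    version_before = y._version  # version recorded when y was saved
    y[0] = 0.0                   # in-place write: bumps y's version counter
    version_after = y._version
    try:
        y.sum().backward()       # autograd notices the version mismatch here
    except RuntimeError as err:
        return version_before, version_after, str(err)
    return version_before, version_after, None


def out_of_place_grad():
    """Same idea, but build a new tensor instead of overwriting y."""
    w = torch.ones(3, requires_grad=True)
    y = w.exp()
    mask = torch.tensor([0.0, 1.0, 1.0])
    z = y * mask                 # new tensor; y itself is left untouched
    z.sum().backward()
    return w.grad


if __name__ == "__main__":
    before, after, message = inplace_error_message()
    print("version before/after:", before, after)
    print("error:", message)
    print("grad without in-place write:", out_of_place_grad())
```

In a large codebase the same check can be done by hand: print `t._version` for a suspect tensor right after the forward pass and again just before `backward()`, and look for whatever code runs in between if the number has changed.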

Hi thecho7, thanks for your reply. I will look into this, but I wanted to run it by you first. I have some code that does this:

for p in self.params:
    if p._grad is not None:
        p._grad.data /= stat['num_steps']

Do you mean that it is incorrect to do that?

That could be the problem.
Try removing it and see whether the error changes.
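
If removing the loop does change the behavior, one equivalent way to get the same averaged gradients without touching `.grad.data` is to divide the loss by the step count before calling `backward()`. The sketch below compares the two on a throwaway `torch.nn.Linear` model. The model, the input, and the function names are assumptions for illustration only, and `num_steps` stands in for `stat['num_steps']`.

```python
import torch


def grads_scaled_after_backward(num_steps):
    """Run backward first, then divide each gradient in place."""
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    x = torch.randn(8, 4)
    model(x).sum().backward()
    with torch.no_grad():              # keep the scaling out of the autograd graph
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(num_steps)
    return [p.grad.clone() for p in model.parameters()]


def grads_from_scaled_loss(num_steps):
    """Divide the loss instead, so no gradient is edited by hand."""
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    x = torch.randn(8, 4)
    loss = model(x).sum() / num_steps
    loss.backward()
    return [p.grad.clone() for p in model.parameters()]


if __name__ == "__main__":
    after = grads_scaled_after_backward(5)
    scaled = grads_from_scaled_loss(5)
    for a, b in zip(after, scaled):
        print(torch.allclose(a, b, atol=1e-6))
```

Both functions give the same gradients up to floating-point rounding, so the second form is a drop-in way to rule the manual scaling in or out as the cause.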