Questions about gradients (in-place modification, data augmentation and tensor attributes)

Hello,
I am currently trying to learn some reinforcement learning and I have some issues with gradient management.
I am in a situation where I have several units on a board, and I create my input this way:

main_input = torch.stack([self.translated(agent_state.main_input, unit.x, unit.y) for unit in agent_state.player.units])

@staticmethod
@torch.no_grad()
def translated(matrix, x, y):
    return torch.roll(torch.roll(matrix, x, dims=-2), y, dims=-1)

I build my input this way in the forward pass because I intend to use memory replay, and saving the whole state would be too big (shape = (n_unit, 5, 20, 20)). Instead I store a single (5, 20, 20) tensor and translate it for each unit.

I wanted to know at which point I should use torch.no_grad() (or something else) to avoid issues with the gradient step: while creating the main_input in the AgentState class, or is it okay to just put it as a decorator on the translated function?
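To make the question concrete, here is a minimal sketch of the two placements I am hesitating between (the board tensor and the offsets below are placeholders, not my real data):

import torch

board = torch.rand(5, 20, 20)  # stands in for agent_state.main_input

# Option 1: only the translation runs without grad tracking (decorator, as above)
@torch.no_grad()
def translated(matrix, x, y):
    return torch.roll(torch.roll(matrix, x, dims=-2), y, dims=-1)

# Option 2: build the whole stacked input under a context manager instead
with torch.no_grad():
    main_input = torch.stack([translated(board, dx, dy) for dx, dy in [(1, 2), (3, 4)]])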

Moreover, while debugging I noticed that my tensor seems to be repeated inside its own .data attribute, which I guess is not normal, and I cannot figure out how this behaviour can happen.

Lastly, after a long time of investigation, I still get an error, which is certainly related to my first two questions:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [50, 6]], which is output 0 of TBackward, is at version 3; expected version 1 instead.

If anyone has advice related to these issues, I would be very interested.
Thank you!
Thomas

Today I read a lot about in-place operations and optimizers, and I tried to debug my code, but I still could not find any in-place operation. The full network is available here:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureExtractor(nn.Module):
    def __init__(self, main_input_channels, complementary_input_channels):
        super(FeatureExtractor, self).__init__()
        self.main_input_channels = main_input_channels
        self.complementary_input_channels = complementary_input_channels
        self.out_channels = 100

        self.conv1 = nn.Conv2d(in_channels=main_input_channels, out_channels=8, kernel_size=3, padding_mode='circular')
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding_mode='circular')
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding_mode='circular')
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 + complementary_input_channels, self.out_channels)

    def forward(self, x1, x2):
        x1 = self.pool(F.relu(self.conv1(x1)))
        x1 = self.pool(F.relu(self.conv2(x1)))
        x1 = F.relu(self.conv3(x1))
        x1 = x1.view(-1, 32)
        x2 = x2.unsqueeze(0).expand(x1.shape[0], -1)
        x = torch.cat((x1, x2), 1)
        x = F.relu(self.fc(x))
        return x


class ActorCriticNetwork(nn.Module):
    def __init__(self, main_input_channels, complementary_input_channels):
        super(ActorCriticNetwork, self).__init__()
        self.feature_extractor_actor = FeatureExtractor(main_input_channels, complementary_input_channels)
        self.fc_actor = nn.Linear(self.feature_extractor_actor.out_channels, 50)
        self.fc_actor_ship = nn.Linear(50, 6)
        self.fc_actor_shipyard = nn.Linear(50, 2)

        self.feature_extractor_critic = FeatureExtractor(main_input_channels, complementary_input_channels)
        self.fc_critic_1 = nn.Linear(self.feature_extractor_critic.out_channels, 50)
        self.fc_critic_2 = nn.Linear(50, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward_actor_shipyard(self, agent_state: AgentState):
        main_input = torch.stack([self.translated(agent_state.main_input, shipyard.x, shipyard.y) for shipyard in agent_state.player.shipyards])
        x = self.feature_extractor_actor(main_input, agent_state.complementary_input)
        x = F.relu(self.fc_actor(x))
        x = self.fc_actor_shipyard(x)
        return self.softmax(x)

    def forward_actor_ship(self, agent_state: AgentState):
        main_input = torch.stack([self.translated(agent_state.main_input, ship.x, ship.y) for ship in agent_state.player.ships])
        x = self.feature_extractor_actor(main_input, agent_state.complementary_input)
        x = F.relu(self.fc_actor(x))
        x = self.fc_actor_ship(x)
        return self.softmax(x)

    def forward_critic(self, agent_state: AgentState):
        x = self.feature_extractor_critic(agent_state.main_input.unsqueeze(0), agent_state.complementary_input.clone())
        x = F.relu(self.fc_critic_1(x))
        x = self.fc_critic_2(x)
        return x

    @staticmethod
    @torch.no_grad()
    def translated(matrix, x, y):
        return torch.roll(torch.roll(matrix, x, dims=-2), y, dims=-1)

Actions predicted by the network are stored inside a container:

class Agent:
    def __init__(self, actor_critic_network: ActorCriticNetwork):
        self.network = actor_critic_network

    def __call__(self, agent_state: AgentState, counter=None) -> Action:
        prediction_shipyards, prediction_ships = [], []

        if agent_state.player.shipyards:
            prediction_shipyards = self.network.forward_actor_shipyard(agent_state)

        if agent_state.player.ships:
            prediction_ships = self.network.forward_actor_ship(agent_state)

        action = Action(prediction_shipyards, prediction_ships, agent_state, counter=counter)
        return action

class Action:
    def __init__(self, prediction_shipyards, prediction_ships, agent_state, counter=None):
        self.prediction_shipyards = prediction_shipyards
        self.prediction_ships = prediction_ships
...

and this method is used for learning:

    def learn(self, state: AgentState, action: Action, state_: AgentState):
        self.counter += 1

        value_pred = self.network.forward_critic(state)
        value_pred_ = self.network.forward_critic(state_)
        reward = state.value - state_.value
        target = reward + self.gamma * value_pred_.detach()
        delta = target - value_pred.detach()

        self.critic_loss += F.mse_loss(value_pred, target)
        actor_loss_function = custom_loss(delta)

        if state.player.shipyards:
            self.actor_loss += actor_loss_function(action.prediction_shipyards, action.tensor_action_shipyards)

        if state.player.ships:
            self.actor_loss += actor_loss_function(action.prediction_ships, action.tensor_action_ships)

        if self.counter % self.optimization_frequency == 0:
            self.actor_loss.backward()
            self.critic_loss.backward()
            self.optimizer_actor.step()
            self.optimizer_critic.step()
            self.optimizer_actor.zero_grad()
            self.optimizer_critic.zero_grad()
            self.actor_loss = 0
            self.critic_loss = 0

I get this traceback:

Connected to pydev debugger (build 193.6494.30)
[W ..\torch\csrc\autograd\python_anomaly_mode.cpp:60] Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "C:\Program Files\JetBrains\PyCharm 2019.3.3\plugins\python\helpers\pydev\pydevd.py", line 2127, in <module>
    main()
  File "C:\Program Files\JetBrains\PyCharm 2019.3.3\plugins\python\helpers\pydev\pydevd.py", line 2118, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "C:\Program Files\JetBrains\PyCharm 2019.3.3\plugins\python\helpers\pydev\pydevd.py", line 1427, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "C:\Program Files\JetBrains\PyCharm 2019.3.3\plugins\python\helpers\pydev\pydevd.py", line 1434, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2019.3.3\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "D:/Thomas/Python/NewHalite/code/A3C.py", line 124, in <module>
    alt.run()
  File "D:/Thomas/Python/NewHalite/code/A3C.py", line 86, in run
    actions = self.agent(player_state[active_id], self.counter)
  File "D:\Thomas\Python\NewHalite\code\agent.py", line 18, in __call__
    prediction_ships = self.network.forward_actor_ship(agent_state)
  File "D:\Thomas\Python\NewHalite\code\network.py", line 58, in forward_actor_ship
    x = self.fc_actor_ship(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1674, in linear
    ret = torch.addmm(bias, input, weight.t())
 (function print_stack)
Traceback (most recent call last):
  File "D:/Thomas/Python/NewHalite/code/A3C.py", line 124, in <module>
    alt.run()
  File "D:/Thomas/Python/NewHalite/code/A3C.py", line 97, in run
    self.learn(player_state[agent_id], player_actions[agent_id], player_state_[agent_id])
  File "D:/Thomas/Python/NewHalite/code/A3C.py", line 64, in learn
    self.actor_loss.backward()
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [50, 6]], which is output 0 of TBackward, is at version 3; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

If you have any insights related to this issue, I would be really interested!
Best,
Thomas

Looks like your in-place operation is the += here:

self.actor_loss += actor_loss_function(action.prediction_ships, action.tensor_action_ships)

+= is fine when the mutated tensor was not captured beforehand, but here I'd guess that custom_loss() captures its output.
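To illustrate the difference (made-up tensors, not your custom_loss, just the pattern):

import torch

x = torch.randn(4, requires_grad=True)

safe = x * 2          # not captured: the backward of a scalar multiply does not save its output
safe += 1             # in-place add is fine here
safe.sum().backward()

x.grad = None
risky = torch.softmax(x, dim=0)  # captured: softmax's backward reuses its own output
risky += 1                       # bumps the version of the saved tensor
risky.sum().backward()           # -> RuntimeError about an in-place operation

If that is the case here, the out-of-place form self.actor_loss = self.actor_loss + ... would avoid mutating a captured tensor.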


Thank you Alex!
I tried to modify the three lines, but unfortunately the error persists and the traceback is exactly the same as the one above.


I can't spot other potential issues. I think fc_actor_ship.weight is somehow changed during the iterations where you skip optimization, so your pending backward graph becomes stale.
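Roughly the failure mode I mean, as a standalone sketch with made-up sizes (not your code):

import torch
import torch.nn as nn

layer = nn.Linear(50, 6)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

inp = torch.randn(4, 50, requires_grad=True)
old_loss = layer(inp).sum()                  # graph captures layer.weight at its current version

layer(torch.randn(4, 50)).sum().backward()   # some other loss is backpropagated in between
opt.step()                                   # in-place update: layer.weight version changes

old_loss.backward()                          # RuntimeError: the saved weight was modified in place

If that is what happens in your run, any loss term kept alive across an optimizer.step() on the same parameters will hit this.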


Thank you for your answers! I will keep searching and will post the answer here if I can find the cause.

By the way, can you explain what you meant in your first answer? I am a little confused by the terms "mutated" and "captured" in this context.

Thanks again

mutated = a mutable object was changed; in this case I mean that the tensor's data is modified in place.
captured = a reference to the tensor is stored by the autograd engine (or via autograd.Function.save_for_backward). It is an error to change such tensors in place (tensors have a ._version attribute to track this; your error message mentions these versions).
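For example (nothing specific to your code, just the counter):

import torch

t = torch.zeros(3)
print(t._version)   # 0
t.add_(1)           # any in-place op
print(t._version)   # 1 -> autograd compares this against the version recorded when t was captured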
