GRU throws 'RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation' after second .backward()

I am trying to implement the PyTorch REINFORCE example on top of a basic RNN with a GRU in it (similar in style to this).

While testing my implementation, I get the error RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation, always when I call .backward() a second time. I turned to my ol’ pal Google, but I was not able to pinpoint the issue in my code. My suspicion is that it has something to do with the hidden state and the GRU, and that I am not doing something the intended way.

Here is a stripped-down version of my code, where every value that is based on the response of the reinforcement learning environment is substituted with torch.rand().

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

torch.autograd.set_detect_anomaly(True)

class MinimalGru(nn.Module):
    def __init__(self):
        super(MinimalGru, self).__init__()

        self.rnn = nn.GRU(input_size=(7 + 3 + 1),
                          hidden_size=8)
        self.policy_head = nn.Linear(8, 4)
        self.value_head = nn.Linear(8, 1)
        # initial hidden state used when restarting an episode
        self.h0 = torch.zeros([1, 8], dtype=torch.float, requires_grad=True)
        self.rnn_hidden_state = None

    def initialize_hidden_state(self, batch_size):
        # shape (num_layers * num_directions, batch, hidden_size)
        self.rnn_hidden_state = torch.cat([self.h0.clone() for _ in range(batch_size)], dim=0).unsqueeze(0)

    def reset_hidden_state_element(self, index):
        self.rnn_hidden_state[0, index] = self.h0.clone()

    def forward(self, state, a_t, r_t):
        in_put = torch.cat((state, a_t, r_t), dim=1).unsqueeze(0)
        z, self.rnn_hidden_state = self.rnn(in_put, self.rnn_hidden_state)
        action_scores = self.policy_head(z.squeeze(0))
        value_score = self.value_head(z.squeeze(0))
        return torch.softmax(action_scores, dim=1), torch.sigmoid(value_score)

def reinforce():
    number_of_episodes = 123456
    length_of_episode = 123
    model = MinimalGru()
    model.initialize_hidden_state(batch_size=1)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for i_episode in range(number_of_episodes):
        log_probs = []
        for i_step in range(length_of_episode):
            state = torch.rand([1, 7])
            a_t = torch.rand([1, 3])
            r_t = torch.rand([1, 1])

            probs, value = model(state, a_t, r_t)

            m = Categorical(probs)
            action = m.sample()
            log_probs.append(m.log_prob(action))
            action = action.item()

        print('End of episode {}'.format(i_episode))

        # Calculate loss
        log_probs = torch.cat(log_probs, dim=0)
        policy_loss = -log_probs * torch.rand([length_of_episode])
        loss = policy_loss.mean()

        # Update policy
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        # The game ends, reset the hidden state of the GRU
        # 'index=0' stands for the finished element of the batch, and we only have one game ATM
        model.reset_hidden_state_element(index=0)

if __name__ == "__main__":
    reinforce()

Here is the error message:

End of episode 0
End of episode 1
Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 83, in <module>
    reinforce()
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 59, in reinforce
    probs, value = model(state, a_t, r_t)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 34, in forward
    z, self.rnn_hidden_state = self.rnn(in_put, self.rnn_hidden_state)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 727, in forward
    self.dropout, self.training, self.bidirectional, self.batch_first)
 (print_stack at /opt/conda/conda-bld/pytorch_1587428266983/work/torch/csrc/autograd/python_anomaly_mode.cpp:60)
Traceback (most recent call last):
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 83, in <module>
    reinforce()
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 75, in reinforce
    loss.backward(retain_graph=True)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 100, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [8, 24]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Thank you in advance!

P.S.: If writing down in words what I intend to do during training would help, please tell me. (I tried to write it down to begin with, but got into too much detail and decided that the code is probably clearer to understand.)

Edit: The tensor in the error message has the shape of [torch.FloatTensor [hidden_size, 3 * hidden_size]]
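
For what it's worth, this shape looks like the transpose of the GRU's hidden-to-hidden weight, which can be checked directly (using the MinimalGru class above):

model = MinimalGru()
# weight_hh_l0 has shape (3 * hidden_size, hidden_size) = (24, 8);
# the [8, 24] tensor from the error message is its transpose
print(model.rnn.weight_hh_l0.shape)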

I stripped my code down further (after switching to GRUCell instead of the regular GRU, which by itself reproduced exactly the same behavior as in my initial post) and managed to get a better error message from the anomaly detector, which points to something regarding the biases of the GRU.

Code:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical


class MinimalGru(nn.Module):
    def __init__(self):
        super(MinimalGru, self).__init__()

        self.rnn = nn.GRUCell(input_size=7, hidden_size=100)

        # The heads read from the GRUCell output, which has shape (batch, hidden_size)
        self.policy_head = nn.Linear(100, 4)
        self.value_head = nn.Linear(100, 1)

        self.rnn_hidden_state = None

    def initialize_hidden_state(self, batch_size):
        # shape (batch, hidden_size) for GRUCell
        self.rnn_hidden_state = torch.zeros([batch_size, 100],
                                            dtype=torch.float)

    def reset_hidden_state_element(self, index):
        self.rnn_hidden_state[index] = torch.zeros([100], dtype=torch.float)

    def forward(self, state):

        self.rnn_hidden_state = self.rnn(state.clone(), self.rnn_hidden_state)

        action_scores = self.policy_head(self.rnn_hidden_state)
        value_score = self.value_head(self.rnn_hidden_state)

        return torch.softmax(action_scores, dim=1), torch.sigmoid(value_score)


def reinforce():
    torch.autograd.set_detect_anomaly(True)

    number_of_episodes = 123456
    length_of_episode = 123

    model = MinimalGru().train()
    model.initialize_hidden_state(batch_size=1)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for i_episode in range(number_of_episodes):

        log_probs = []
        optimizer.zero_grad()

        for i_step in range(length_of_episode):
            state = torch.rand([1, 7])
            a_t = torch.rand([1, 3])
            r_t = torch.rand([1, 1])

            probs, value = model(state)

            m = Categorical(probs)
            action = m.sample()
            log_probs.append(m.log_prob(action))
            action = action.item()

        print('End of episode {}'.format(i_episode))

        # Calculate loss
        log_probs_tensor = torch.cat(log_probs, dim=0)
        policy_loss = -log_probs_tensor * torch.rand([length_of_episode])
        loss = policy_loss.mean()

        # Update policy
        loss.backward(retain_graph=True)
        optimizer.step()

        # The game ends, reset the hidden state of the GRU AFTER the backward()
        model.reset_hidden_state_element(index=0)


if __name__ == "__main__":
    reinforce()

Error message:

End of episode 0
End of episode 1
Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 80, in <module>
    reinforce()
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 57, in reinforce
    probs, value = model(state)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 29, in forward
    self.rnn_hidden_state = self.rnn(state.clone(), self.rnn_hidden_state)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 1033, in forward
    self.bias_ih, self.bias_hh,
 (print_stack at /opt/conda/conda-bld/pytorch_1587428266983/work/torch/csrc/autograd/python_anomaly_mode.cpp:60)
Traceback (most recent call last):
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 80, in <module>
    reinforce()
  File "/home/andris/Repos/github/drl-2048/src/reinforce/gru_minimal.py", line 72, in reinforce
    loss.backward(retain_graph=True)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/andris/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 100, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 300]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Disabling bias in the GRUCell did not solve the problem; the error message remained the same.

I took a look at the GRUCell source and saw that if one passes None as the hidden state, the cell initializes it to zeros by itself.
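
A minimal illustration of that implicit initialization (a standalone snippet, assuming the same imports as above):

cell = nn.GRUCell(input_size=7, hidden_size=100)
x = torch.rand([1, 7])
h = cell(x, None)  # passing hx=None makes GRUCell create a zero hidden state internally
print(h.shape)     # torch.Size([1, 100])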

With this information I changed my code further: instead of manually setting the hidden state to zeros, I set it to None, relying on the implicit initialization. This seems to solve my issue, as I no longer encounter the error message.

Now I just have to come up with a solution where I can reset the hidden state manually.

Code:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical


class MinimalGru(nn.Module):
    def __init__(self):
        super(MinimalGru, self).__init__()

        self.rnn = nn.GRUCell(input_size=7, hidden_size=100)

        # The heads read from the GRUCell output, which has shape (batch, hidden_size)
        self.policy_head = nn.Linear(100, 4)
        self.value_head = nn.Linear(100, 1)

        self.rnn_hidden_state = None

    def forward(self, state):
        self.rnn_hidden_state = self.rnn(state, self.rnn_hidden_state)
        action_scores = self.policy_head(self.rnn_hidden_state)
        value_score = self.value_head(self.rnn_hidden_state)

        return torch.softmax(action_scores, dim=1), torch.sigmoid(value_score)


def reinforce():
    torch.autograd.set_detect_anomaly(True)

    number_of_episodes = 123456
    length_of_episode = 123

    model = MinimalGru().train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for i_episode in range(number_of_episodes):

        log_probs = []
        optimizer.zero_grad()

        for i_step in range(length_of_episode):
            state = torch.rand([1, 7])
            a_t = torch.rand([1, 3])
            r_t = torch.rand([1, 1])

            probs, value = model(state)

            m = Categorical(probs)
            action = m.sample()
            log_probs.append(m.log_prob(action))
            action = action.item()

        print('End of episode {}'.format(i_episode))

        # Calculate loss
        log_probs_tensor = torch.cat(log_probs, dim=0)
        policy_loss = -log_probs_tensor * torch.rand([length_of_episode])
        loss = policy_loss.mean()

        # Update policy
        loss.backward(retain_graph=True)
        optimizer.step()

        # The game ends, reset the hidden state of the GRU AFTER the backward()
        model.rnn_hidden_state = None


if __name__ == "__main__":
    reinforce()

Edit:
I tried substituting torch.zeros([1, 100], dtype=torch.float) for None when initializing the hidden state, and it worked as expected. From that it seems that my initial error originates from the reset_hidden_state_element function shown above.

The solution was to zero out a given index of the hidden state by multiplying it with zero, not by replacing it in place with torch.zeros().
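
A minimal sketch of that reset, based on the GRUCell version above (the mask is just one way to express "multiply this index by zero"):

def reset_hidden_state_element(self, index):
    # Build a mask that is 0 for the finished episode and 1 everywhere else,
    # then produce a new hidden state by out-of-place multiplication instead
    # of writing into the existing tensor.
    mask = torch.ones_like(self.rnn_hidden_state)
    mask[index] = 0.0
    self.rnn_hidden_state = self.rnn_hidden_state * mask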