Problem training a model: autograd says a tensor needed for backpropagation was modified by an in-place operation, but I don't see where

Hello, I'm trying to build a D3QN agent, but I have problems when I try to train the model.
For some reason there is no problem in the first iteration, but in the second one it gives me the following error:

    264     # some Python versions print out the first line of a multi-line function
    265     # calls in the traceback and some print out the last line
--> 266     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    267         tensors,
    268         grad_tensors_,

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [448, 112]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
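For reference, here is a tiny standalone example (not my training code, just to illustrate what the error means) where a tensor that backward() still needs is changed in place after the forward pass:

    import torch

    w = torch.randn(3, 3, requires_grad=True)
    x = torch.randn(3)

    h = torch.relu(w @ x)   # relu saves its output h for the backward pass
    loss = h.sum()
    h.add_(1.0)             # in-place change to a tensor saved for backward
    loss.backward()         # RuntimeError: ... modified by an inplace operation

In my own code I don't see any explicit in-place operation like that, which is why I'm stuck.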

Here are the functions that I built for training the model:

  def train_batch(self, batch):
    i=0
    with torch.autograd.detect_anomaly():

      for indicator, winned, hidden, action, reward, next_indicator, next_winned, done in batch:
        print(i)
        i += 1

        self.train_once(indicator, winned, hidden, action, reward, next_indicator, next_winned, done)
  # this will be used for train short memory and train long memory
  def train_once(self,indicator,winned,hidden,action,reward,next_indicator,next_winned,done):

    indicator = torch.tensor(indicator)
    winned = torch.tensor(winned)
    next_indicator = torch.tensor(next_indicator)
    next_winned = torch.tensor(next_winned)

    pred, hidden = self.model(indicator, winned, hidden)
    target = pred.clone()
    
    with torch.no_grad():
        Q_new = reward
        if not done:
            out, _ = self.model(next_indicator, next_winned, hidden)
            Q_new = reward + self.gamma * torch.max(out)
    
    target[0][action] = torch.tensor(Q_new)
    self.optimizer.zero_grad()

    loss = F.mse_loss(pred, target)
    loss.backward()
    self.optimizer.step()

Here is the model:

class Model(nn.Module):

  def __init__(self,interval_input,number_indicators,kernel_length=8,kernels=16,lstm_num_layers=40,custom_lstm_size=False,lstm_hidden_length_custom=80):

    super(Model,self).__init__()
    length, convolutional_layers = get_number_cnn(interval_input, kernel_length)
    # this is for the lstm, it is important
    lstm_hidden_length = lstm_hidden_length_custom if custom_lstm_size else length*kernels+number_indicators*2+20
    # we store these for when we need to compute certain things later
    self.lstm_num_layers = lstm_num_layers
    self.lstm_hidden_length = lstm_hidden_length
    # we define the convolutional neural network
    self.conv = nn.Sequential(nn.Conv1d(number_indicators, kernels, kernel_length),
                            nn.ReLU(),
                            *([nn.Conv1d(kernels,kernels,kernel_length),nn.ReLU()]*(convolutional_layers)),
                            )


    self.lstm=nn.LSTM(length*kernels+number_indicators*2+2,lstm_hidden_length,num_layers=lstm_num_layers)
    # value stream (the value/advantage split improves learning)
    self.value=nn.Sequential( 
        nn.Linear(in_features=int(lstm_hidden_length),out_features=int(lstm_hidden_length*0.7)),
        nn.ReLU(),
        nn.Linear(in_features=int(lstm_hidden_length*0.7),out_features=int(lstm_hidden_length*0.7*0.7)),
        nn.ReLU(),
        nn.Linear(in_features=int(lstm_hidden_length*0.7*0.7),out_features=1)
        )
    
    #number of actions
    self.adv=nn.Sequential( 
        nn.Linear(in_features=int(lstm_hidden_length),out_features=int(lstm_hidden_length*0.7)),
        nn.ReLU(),
        nn.Linear(in_features=int(lstm_hidden_length*0.7),out_features=int(lstm_hidden_length*0.7*0.7)),
        nn.ReLU(),
        nn.Linear(in_features=int(lstm_hidden_length*0.7*0.7),out_features=3)
        )

  def empty_lstm_attributes(self):
    h0 = torch.zeros(self.lstm_num_layers, self.lstm_hidden_length,dtype=torch.float32)
    c0 = torch.zeros(self.lstm_num_layers, self.lstm_hidden_length,dtype=torch.float32)
    return (h0,c0)

  def forward(self,indicators,winned,hidden,future=0):
    output=self.conv(indicators)
    output=torch.cat(
        (
          winned, #i join the two inputs for the lstm
        output.view(output.size(0), -1)
        ), #here is the result of the convolutional neural network
         dim=-1)
    output,new_hidden=self.lstm(output,hidden)
    adv=self.adv(output)
    value=self.value(output)
    q = value + adv - adv.mean()

    return (q,new_hidden)

Hi Ranon!

I am going to speculate as follows:

The first time through, hidden doesn’t depend on the parameters of lstm.
However, the second time through, the (old) hidden depends on the parameters
of lstm as they were before they were updated by optimizer.step().

In general, optimizer.step() modifies the parameters of the model being
optimized in place.
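
For example, here is a toy sketch (not your model) showing that the parameter
object stays the same but its contents are overwritten by step():

    import torch

    lin = torch.nn.Linear(2, 2)
    opt = torch.optim.SGD(lin.parameters(), lr=0.1)

    w = lin.weight                      # a reference to the parameter, not a copy
    old_values = w.detach().clone()

    lin(torch.ones(1, 2)).sum().backward()
    opt.step()                          # in-place update of the parameters

    print(w is lin.weight)              # True: still the same tensor object
    print(torch.equal(w, old_values))   # False: its values were changed in place

Any old computation graph that saved lin.weight for its backward pass will now
see a tensor whose version counter has advanced.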

When you call .backward() the second time, loss, through target, depends
on the old hidden, which depends on the old lstm parameters, so you get
the inplace-modification error.

Would it make sense for your use case (assuming that this is the cause of your
error) to .detach() the old hidden from the computation graph? You would
still pass the old hidden into model so the new pred and hidden would
depend on the old hidden’s values, but those values would be .detach()ed
from their dependency on the old values of lstm’s parameters.
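
In code, the change could be as small as this (a sketch against your train_once,
assuming hidden is the usual (h, c) tuple your model returns):

    # at the top of train_once, before the forward pass
    hidden = tuple(h.detach() for h in hidden)   # keep the values, drop the old graph
    pred, hidden = self.model(indicator, winned, hidden)

That way each call to train_once backpropagates only through its own forward
pass, and the in-place parameter update done by optimizer.step() in the
previous call no longer matters.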

Aside from this speculation, please see the suggestions for locating and fixing
inplace-modification errors given in the following post:

Good luck!

K. Frank


It works now, thank you Frank, God bless you.