Loss gradient error

I’m getting the following error whenever I try to backpropagate the loss:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 10, 128]], which is output 0 of SoftmaxBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
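
From what I understand, this error means that some in-place operation incremented the version counter of a tensor that autograd had saved for the backward pass. A tiny standalone snippet (nothing to do with my model) reproduces the same message, because SoftmaxBackward reuses the softmax output:

  import torch

  x = torch.randn(4, 3, requires_grad=True)
  p = torch.softmax(x, dim=1)  # autograd saves p itself for SoftmaxBackward
  p += 1.0                     # in-place op: p is now "at version 1"
  p.sum().backward()           # raises the RuntimeError above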

Here is the code that generates the loss:

  self.batch_size = batch_size = self.args.batch_size
  feature_size = self.data_set.train.feature_size
  self.num_blocks = num_blocks = self.data_set.train.num_of_blocks
  rnn_dim = self.args.rnn_dim
  max_length = self.data_set.maxlength
  spatial_embed_dim = self.args.spatial_embed_dim

  self.sentences_pl = torch.LongTensor(np.random.randint(self.data_set.vocabsize, size=(batch_size, self.data_set.maxlength))) 
  self.lengths_pl = torch.IntTensor(np.random.randint(max_length, size=batch_size)) 

  self.worlds_pl = torch.FloatTensor(np.random.randint(spatial_embed_dim, size=(batch_size, num_blocks, 3))) 

  self.features_pl = torch.FloatTensor(np.random.randint(feature_size, size=(batch_size, num_blocks, feature_size))) 

  self.source_pl = torch.LongTensor(np.random.randint(num_blocks, size=batch_size)) 
  self.target_pl = torch.FloatTensor(np.random.randint(num_blocks, size=(batch_size, 3))) 

  self.keep_prob = 1 - self.args.dropout 
  self.global_step = 0                                             

  if self.args.embed_dim == -1:       # Default is -1 now   
      if initial:
          print("Model: Using hot Vector with vocabsize", self.data_set.vocabsize)
      embeddings = torch.FloatTensor(np.identity(self.data_set.vocabsize)) 
      input_dim = self.data_set.vocabsize
  else:
      if initial:
          print("Model: Using word embed with embed size", self.args.embed_dim)
      embeddings = torch.FloatTensor(np.random.uniform(-1, 1, size=(self.data_set.vocabsize, self.args.embed_dim))) 
      input_dim = self.args.embed_dim
  # b, lens -> b, lens, embed
  word_embedded = F.embedding(self.sentences_pl, embeddings)

  # RNN
  if self.args.GRU:
      self.forward = nn.GRU(input_size=input_dim, hidden_size=rnn_dim, dropout=1-self.keep_prob)
      hidden = torch.randn(1, word_embedded.shape[1], batch_size*2)
      nn.init.xavier_uniform_(hidden, gain=nn.init.calculate_gain('relu'))
      rnn_outputs, h_t = self.forward(word_embedded, hidden)
  else:
      self.forward = nn.LSTM(input_size=input_dim, hidden_size=rnn_dim, dropout=1-self.keep_prob)
      hidden = torch.randn(1, word_embedded.shape[1], batch_size*2)
      cell = torch.randn(1, word_embedded.shape[1], batch_size*2)
      nn.init.xavier_uniform_(hidden, gain=nn.init.calculate_gain('relu'))
      nn.init.xavier_uniform_(cell, gain=nn.init.calculate_gain('relu'))
      rnn_outputs, (h_t, c_t) = self.forward(word_embedded, (hidden, cell))

  # Get the last output of the RNN for each sequence (row index = b * max_length + length - 1)
  index = torch.arange(0, batch_size) * max_length + (self.lengths_pl - 1) 
  flat = torch.reshape(rnn_outputs, (-1, rnn_dim)) 
  rnn_fo = flat[index] 

  sentence_embedded = rnn_outputs

  spatial_embedded = self.features_pl
  spatial_dim = feature_size
  block_outer_attention_layer = None
  block_inner_attention_layer = None

  self.optimizer = optim.Adam(self.forward.parameters(), lr=self.args.lr)

  target_task_attvecs = rnn_outputs

  if self.args.wordCNN is not None:
      target_ref_attvecs = self.cnn_layer(sentence_embedded)
      target_ref_attvecs = torch.unsqueeze(target_ref_attvecs, 1).repeat(1, num_blocks, 1)
      if initial:
          print("Model: Size of traget ref attvecs is ", list(target_ref_attvecs.shape))
      target_block_logits = InnerAttention(initial, target_ref_attvecs, spatial_embedded).setup()
      # if block_inner_attention_layer is None:
      #     block_inner_attention_layer = InnerAttention(target_ref_attvecs, spatial_embedded,
      #                                                  shape2=[batch_size, num_blocks, spatial_dim])
      # target_block_logits = block_inner_attention_layer.setup(target_ref_attvecs, spatial_embedded)
  else:
      if block_outer_attention_layer is None:  # Handle shared parameters
          block_outer_attention_layer = OuterAttention(target_task_attvecs, spatial_embedded)
      target_block_outer_attvecs = block_outer_attention_layer.setup(target_task_attvecs, spatial_embedded,
                                                                      reduced_dim=1)
      if block_inner_attention_layer is None:
          block_inner_attention_layer = InnerAttention(initial, target_block_outer_attvecs, spatial_embedded)
      target_block_logits = block_inner_attention_layer.setup(target_block_outer_attvecs, spatial_embedded)

  self.target_block_logits = target_block_logits

  self.target_block_probs = target_block_probs = F.softmax(self.target_block_logits)
  self.target_block_probs_stddev = get_stddev(target_block_probs)

  if self.args.no_offset:
      if initial:
          print("Model: Warning: Do not use offset, use block Only")
      # predict_position = ref
      self.offset = offset = torch.zeros((batch_size, 3))
  else:
      if self.args.discrete_offset is not None:
          if initial:
              print("Model: Using weighted directions Offset")
          # Get the vector of sentence
          if self.args.wordCNN is not None:
              if initial:
                  print("\t with CNN Model")
              offset_attvec = self.cnn_layer(sentence_embedded)
          else:
              if initial:
                  print("\t with IndependentAttention Layer")
              offset_attvec, self.offset_attention = IndependentAttention(target_task_attvecs).setup(
                  reduced=True)

          self.nDirs = self.args.discrete_offset
          if initial:
              print("\t with directions ", self.nDirs)
          if self.args.fixedDirs:
              if initial:
                  print("\t with fix directions")
              self.directions = self.create_directions(self.nDirs, initial, length=0.2)
          else:
              self.directions = Variable(
                  torch.fmod(torch.normal(0.0, std=1.0 / math.sqrt(float(self.nDirs)), size=(self.nDirs, 3)), 2)
              )

          sentence_dim = list(offset_attvec.shape)[1]
          weights = Variable(
              torch.fmod(torch.normal(0.0, std=1.0 / math.sqrt(float(sentence_dim)), size=(sentence_dim, self.nDirs)), 2)
          )
          biases = Variable(torch.zeros(self.nDirs))
          # b,dim * dim, 9 = b, 9
          self.weight_directions = torch.matmul(offset_attvec, weights)  # + biases

          if self.args.softmaxOffset:
              if initial:
                  print("\t with softmax of weight directions and BIASES")
              self.weight_directions = self.weight_directions + biases
              self.weight_directions = F.softmax(self.weight_directions)

          if self.args.probOffset:
              if initial:
                  print("\t with sampling offset from logits")
              if self.args.softmaxOffset:
                  raise NameError("The prob Offset is confilict with softmax Offset")
              self.weight_directions = self.weight_directions + biases
              choice = torch.multinomial(self.weight_directions, 1)
              choice = torch.squeeze(choice)
              choice = choice.int()
              offset = self.directions[choice]
          else:
              # b, 9 * 9, 3 = b, 3
              offset = torch.matmul(self.weight_directions, self.directions)
      else:
          if initial:
              print("Model.offset: Use two FCNs offset")
          if self.args.wordCNN is not None:
              offset_attvec = self.cnn_layer(sentence_embedded)
          else:
              offset_attvec, self.offset_attention = IndependentAttention(target_task_attvecs).setup(
                  reduced=True)
          offset_fcn0 = FCN(offset_attvec, 64).setup(initial, nonlinear="sigmoid")
          offset = FCN(offset_fcn0, 3, bias=False).setup(initial, nonlinear="")
      self.offset = offset

  self.select_block = torch.argmax(self.target_block_logits, dim=1).int()
  self.ref = ref = fetch_by_index(self.worlds_pl, self.select_block)
  self.get_target_loss()

  self.loss = self.find_loss(initial)
  self.tl_grad = grad(self.loss, self.target_block_logits)  

def find_loss(self, initial=False):
        self.loss = self.get_target_loss()

        self.loss = torch.sum(self.loss.clone())
        ####### Regularizations ####### 
        reg_losses = torch.tensor(REGULARIZATION_LOSSES)
        reg_weight = 0.0005
        self.loss = self.loss.clone() + reg_weight * sum(reg_losses)
        return self.loss

def get_target_loss(self):
        if self.args.varOffset:
            if self.args.softmaxOffset or self.args.probOffset:
                raise NameError("varOffset is incompatible with softmax offset and prob offset")
            rl_3d = torch.unsqueeze(self.target_block_logits, -1)
            ol_3d = torch.unsqueeze(self.weight_directions, 1)
            # b,n,1 + b,1,o -> b,n,o
            joint_logits = rl_3d + ol_3d

            # b,n,3 -> b,n,1,3
            block_position_4d = torch.unsqueeze(torch.tensor(self.worlds_pl), 2)
            # (b, n, 1, 3) +  (*1, *1, o, 3) = (b, n, o, 3)
            predict_positions_4d = block_position_4d + self.directions
            # (b, 3) -> (b, 1, 1, 3)
            target_4d = torch.unsqueeze(torch.unsqueeze(torch.tensor(self.target_pl), 1), 1)
            # (b, n, o, 3) -> (b, n, o)
            dis2 = torch.sum(torch.pow(predict_positions_4d - target_4d, 2), 3)

            joint_loss = nn.Softmax()
            self.target_loss = joint_loss(joint_logits)
            # joint_probs = F.softmax(joint_logits)
            # self.target_loss = torch.sum(torch.mul(self.target_loss, dis2), (1,2))
            shape = (self.batch_size, self.num_blocks * self.nDirs)
            joint_logits_2d = torch.reshape(joint_logits, shape)
            predict_position_3d = torch.reshape(predict_positions_4d, shape+(3,))
            select = torch.argmax(joint_logits_2d, dim=1).int()

            # select_offset = tf.cast(tf.argmax(offse), tf.int32)
            # select_ = self.select_block * num_blocks + select_offset
            # with tf.control_dependencies([tf.assert_equal(select, select_)]):
            self.predict_position = fetch_by_index(predict_position_3d, select)
        else:
            predict_position_3d = torch.tensor(self.worlds_pl) + torch.unsqueeze(self.offset, 1)
            target_3d = torch.unsqueeze(torch.tensor(self.target_pl), 1)
            dis2 = torch.sum(torch.pow(predict_position_3d - target_3d, 2), 2)
            self.target_loss = torch.sum(torch.mul(self.target_block_probs, dis2), 1)

            self.predict_position = torch.tensor(self.ref + self.offset)

        # || ref - target ||^2
        ref_dis2 = torch.sum(torch.pow(self.ref - torch.tensor(self.target_pl), 2), 1)
        target_3d = torch.unsqueeze(torch.tensor(self.target_pl), 1)
        # (b, n, 3) - (b, 1, 3) -> b, n, 3 -> b, n
        blocks_dis2 = torch.sum(torch.pow(torch.tensor(self.worlds_pl) - target_3d, 2), 2)
        # b, n * b, n -> b, n -> b
        # \sum || b_i - target || ^2 * p_i
        weighted_ref_dis2 = torch.sum(torch.mul(self.target_block_probs, blocks_dis2), 1)

        gamma = 0.1

        if self.args.targetLoss == "ref":
            self.target_loss += gamma * ref_dis2
        elif self.args.targetLoss == "weighted":
            self.target_loss += gamma * weighted_ref_dis2
        return self.target_loss

If anyone could provide advice or tips for troubleshooting this issue, I would appreciate it.

You could rerun the code with torch.autograd.set_detect_anomaly(True), as suggested in the error message, to get a (hopefully) helpful stack trace.
I cannot find the offending operation by skimming through the code.
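
For reference, you can enable it globally at the top of the script or scope it to the failing pass with the context manager (model and the loss call below are placeholders for your own objects):

  import torch

  torch.autograd.set_detect_anomaly(True)  # global: affects all subsequent autograd calls

  # or scoped to a single forward/backward pass:
  with torch.autograd.detect_anomaly():
      loss = model.find_loss()
      loss.backward()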

@ptrblck So I reran the code after including torch.autograd.set_detect_anomaly(True), and this is the stack trace that I got as a result:

[W python_anomaly_mode.cpp:60] Warning: Error detected in SoftmaxBackward. Traceback of forward call that caused the error: 
File "pytorch_run.py", line 2437, in <module>
  joint_run() 
File "pytorch_run.py", line 2425, in joint_run
  trainer = Trainer(args)
File "pytorch_run.py", line 1787, in __init__
  model.setup()
File "pytorch_run.py", line 1708, in setup 
  self.loss = self.find_loss(initial) 
File "pytorch_run.py", line 1450, in find_loss 
  self.loss = self.get_source_loss() + self.get_target_loss() 
File "pytorch_run.py", line 1409, in get_target_loss 
  self.target_loss = joint_loss(joint_logits) 
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl 
  result = self.forward(*input, **kwargs) 
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 1140, in forward 
  return F.softmax(input, self.dim, _stacklevel=5) 
File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 1498, in softmax 
  ret = input.softmax(dim) 
(function print_stack)

This error trace seems to indicate that the variable joint_logits is being modified in place and is thus causing the problem, but I’m not sure why: as far as I can see in get_target_loss, it is never modified in place. I would appreciate any insights you might have on this error.
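
To double-check my understanding of the version counter the message refers to, I tried a toy example (._version is an internal attribute, so it may not be stable across releases):

  import torch

  p = torch.softmax(torch.randn(2, 2, requires_grad=True), dim=1)
  print(p._version)  # 0
  p += 1.0           # in-place addition
  print(p._version)  # 1 -> "is at version 1; expected version 0"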

I also can’t tell which line causes this error from the trace alone. You could add .clone() to tensors in the potentially offending method to narrow down the offending line of code.
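
One thing to check, since the failing op is a softmax: in get_target_loss, self.target_loss is the direct output of a softmax and is later updated in place via self.target_loss += gamma * ref_dis2 (or the weighted variant). SoftmaxBackward needs its output unmodified, so an in-place addition there would raise exactly this error. If that turns out to be the offending line, an out-of-place update should fix it; here is the same pattern on toy tensors:

  import torch

  gamma = 0.1
  logits = torch.randn(4, 3, requires_grad=True)
  probs = torch.softmax(logits, dim=1)  # SoftmaxBackward saves probs for backward
  extra = torch.randn(4, 3)

  # probs += gamma * extra              # in-place: backward() would fail
  probs = probs + gamma * extra         # out-of-place: saved output stays at version 0
  probs.sum().backward()                # works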