Backpropagating multiple losses

Hello,
I am training model 1 (in train1) with a specific loss function that involves tensor A. I accumulate the loss and then perform an update. Next, I train a second model, model 2 (in train2), where I want to calculate the gradients w.r.t. A using the loss computed in train2. To do that, I am adding loss1 to loss2.

#reproduce error
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel
import torch
import torch.nn.functional as F
model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
model2 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints


optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
A=torch.rand(1, requires_grad=True)
optimizer3 = torch.optim.SGD([A], lr=0.1)

en_input=torch.tensor([[1,2], [3,4]])
en_masks=torch.tensor([[0,0], [0,0]])
de_output=torch.tensor([[3,1], [4,2]])
de_masks=torch.tensor([[0,0], [0,0]])
lm_labels=torch.tensor([[5,7], [6,8]])

torch.autograd.set_detect_anomaly(True)

def train1():
  acc=torch.zeros(1)
  for i in range(2):
    optimizer1.zero_grad()
    out = model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
                 decoder_attention_mask=de_masks, labels=lm_labels.clone())
    prediction_scores = out[1]
    predictions = F.log_softmax(prediction_scores, dim=2)
    p=((predictions.sum() - de_output.sum())*A).sum()
    p=torch.unsqueeze(p, dim=0)
    acc = torch.cat((p,acc)) # accumulating the loss 

  loss=acc.sum()
  A.retain_grad()
  loss.backward(retain_graph=True) 
  optimizer1.step() 
  return loss


def train2(loss1):
  for i in range(2):
    optimizer3.zero_grad()
    output = model2(input_ids=en_input, attention_mask=en_masks,
                    decoder_input_ids=de_output,
                    decoder_attention_mask=de_masks, labels=lm_labels.clone())

    prediction_scores_ = output[1]
    predictions_= F.log_softmax(prediction_scores_, dim=2)
    loss2=((predictions_.sum() - de_output.sum())).sum()+loss1 # want to calculate gradients wrt A
    A.retain_grad()
    loss2.backward(inputs=[A], retain_graph=True)
    optimizer3.step() # update A based on calculated gradients

loss1=train1()
train2(loss1)

If this is the right method, I am not understanding what's wrong in my code. If it's not right, I would appreciate it if someone pointed me in the right direction.

Error trace:

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py:147: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 845, in launch_instance
    app.start()
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
    self._run_once()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
    handle._run()
  File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
    handler_func(fileobj, events)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 451, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 434, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-66-c603f915c713>", line 78, in <module>
    loss1=train1()
  File "<ipython-input-66-c603f915c713>", line 25, in train1
    p=((predictions.sum() - de_output.sum())*A).sum()
 (Triggered internally at  /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-66-c603f915c713> in <module>()
     77 for i in range(2):
     78   loss1=train1()
---> 79   train2(loss1)

2 frames
<ipython-input-66-c603f915c713> in train2(loss1)
     69     print(A.grad)
     70     #loss2.grad(inputs=A,outputs=A, only_inputs=True)
---> 71     loss2.backward(inputs=[A],retain_graph=True) #calculates gradients # retain_graph=True #list(dec.parameters())
     72     print(A.grad)
     73     # torch.nn.utils.clip_grad_norm_(model1.parameters(), 1.0)

/usr/local/lib/python3.7/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    243                 create_graph=create_graph,
    244                 inputs=inputs)
--> 245         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    246 
    247     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    148 
    149 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Based on the description, it seems you are trying to use stale intermediate activations to calculate the gradients for already updated parameters, which would raise this error.
This post explains the issue in more detail using a GAN training approach.
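To make the failure mode concrete, here is a minimal sketch (a single stand-in parameter rather than the BERT models above, so only an illustration of the mechanism): backward runs once through a graph, the optimizer then updates the parameter in place, and a second backward through the same, now stale, graph raises the same version-counter error.

import torch

w = torch.nn.Parameter(torch.randn(3))   # stand-in for a model parameter
opt = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()            # the graph saves w itself (needed for d(w**2)/dw)
loss.backward(retain_graph=True) # first backward works
opt.step()                       # in-place update of w bumps its version counter

try:
    loss.backward()              # second backward through the stale graph
except RuntimeError as e:
    print(e)                     # "... modified by an inplace operation ..."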

Hey @ptrblck,
Thank you so much for replying :slight_smile:. I understood my mistake. I had another question: is it possible for me to calculate the gradients w.r.t. A if I don't add loss1 to loss2 and simply do loss2.backward(inputs=[A])? Thanks.

That might be possible, as it seems A is used in the loss calculation and might not rely on the aforementioned stale activations. In any case, you could just run the code and see if Autograd raises an error.
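For reference, a minimal sketch of how backward(inputs=[A]) behaves with toy tensors (not the models above): when A actually appears in the loss, only A.grad is populated and the other tensors are left alone.

import torch

A = torch.rand(1, requires_grad=True)
x = torch.randn(3, requires_grad=True)   # stand-in for model parameters

loss2 = (x.sum() * A).sum()   # A participates in the loss
loss2.backward(inputs=[A])    # restrict gradient accumulation to A
print(A.grad)                 # gradient w.r.t. A
print(x.grad)                 # None: x was not requested via inputs=[A]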

Well, I did try running the code without adding loss1, and there is no change in the gradients of A.
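That matches the graph: without loss1, loss2 = ((predictions_.sum() - de_output.sum())).sum() never involves A, so there is no path from loss2 back to A and backward(inputs=[A]) has nothing to accumulate. A toy illustration (hypothetical tensors, assuming A is unreachable from the loss):

import torch

A = torch.rand(1, requires_grad=True)
x = torch.randn(3, requires_grad=True)

loss2 = x.sum()               # built without adding loss1: no dependence on A
loss2.backward(inputs=[A])    # nothing flows back to A
print(A.grad)                 # None: A is unreachable from loss2's graph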