Auxiliary Loss with Gradient Checkpointing in LLMs

I am trying to add an auxiliary loss to an LLM (a LlamaModel from the Hugging Face transformers library), which uses PyTorch's gradient/activation checkpointing utility under the hood. That is, each decoder block is wrapped in an individual checkpoint call.

My problem is that I want to add an auxiliary loss, computed inside each block, to the overall loss. So far I have done this by storing the auxiliary loss on each block during the forward pass, i.e., each block has an attribute “aux_loss” that holds the value from the current forward. After computing the main loss, I gather all auxiliary losses by iterating over the decoder blocks, sum them, and add them to the loss.
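
To make the setup more concrete, here is a minimal sketch of the pattern; the `Block`/`Model` classes, the `aux_head`, and the sizes are placeholders and not my actual code:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    # simplified stand-in for a decoder block that also produces an auxiliary loss
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.aux_head = nn.Linear(dim, 1)
        self.aux_loss = None  # filled in on every forward pass

    def forward(self, x):
        out = self.linear(x)
        # the auxiliary loss is stored on the module instead of being returned
        self.aux_loss = self.aux_head(out).pow(2).mean()
        return out


class Model(nn.Module):
    def __init__(self, dim, n_blocks):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(n_blocks))

    def forward(self, x):
        for block in self.blocks:
            # each decoder block is wrapped in its own checkpoint call
            x = checkpoint(block, x, use_reentrant=True)
        return x


model = Model(16, 3)
x = torch.randn(4, 16, requires_grad=True)  # in the real model the hidden states come from the embedding layer
main_loss = model(x).mean()
# after the forward pass, gather the auxiliary losses from all blocks
aux_loss = sum(block.aux_loss for block in model.blocks)
(main_loss + aux_loss).backward()
```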

This is currently not working, and I would like to know if it is because of the gradient checkpointing that gradients are not properly backpropagated. My goal right now is to find out whether this is the root cause and how to fix it, or whether I need to look elsewhere. In addition, if at all possible, I would like to observe what actually happens during the backward pass to understand it better, and I would like to know if there is a tool for this.

I appreciate any help on this, since I have been stuck on this for a couple of days now. Thank you.
Since there is a lot of code involved, it's difficult to share all of it here; however, I am happy to share specific snippets if desired.

That’s an interesting use case, and while I don’t know what exactly fails, I would assume you are seeing some errors during the backward call?
Is your aux loss calculation working if no checkpointing is used?

I tried it without checkpointing before, but then I got an OOM error. I did not get any errors during the backward call, but the aux loss did not improve. However, I think I can try this with a dummy model to see if this is the problem.

For now, do you know if the PyTorch checkpoint utility relies on the outputs of the forward pass to properly do the backward pass? My concern is that, by setting this aux loss “outside” of the model flow, it might go undetected.

I tried another architecture first, which is basically the other way round: instead of the blocks setting an aux loss, they use an attribute “scores” for routing inputs to experts, which is set by a router layer in the first block.

Architecture (see setup 2; a rough sketch follows the list):

  • Block 1:
    • x → original forward → out (block output)
    • x → router → scores (set on the other blocks as attribute)
  • Block 2 & 3:
    • x → original forward → out
    • x → experts → scores * experts → sum of weighted expert outputs (exp_outs)
      => out + exp_outs (block output)
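
A rough sketch of setup 2 (module names, sizes, and the number of experts are made up for illustration; the actual code is in the gist):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class RouterBlock(nn.Module):
    # block 1: computes its own output plus routing scores for the later blocks
    def __init__(self, dim, n_experts):
        super().__init__()
        self.ff = nn.Linear(dim, dim)
        self.router = nn.Linear(dim, n_experts)

    def forward(self, x):
        out = self.ff(x)
        scores = self.router(x).softmax(dim=-1)  # routing weights
        return out, scores


class ExpertBlock(nn.Module):
    # blocks 2 & 3: original forward plus a weighted sum of expert outputs
    def __init__(self, dim, n_experts):
        super().__init__()
        self.ff = nn.Linear(dim, dim)
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))
        self.scores = None  # set from outside by the router block

    def forward(self, x):
        out = self.ff(x)
        exp_outs = torch.stack([e(x) for e in self.experts], dim=-1)
        # weight each expert output by the routing score and sum over experts
        exp_outs = (exp_outs * self.scores.unsqueeze(-2)).sum(dim=-1)
        return out + exp_outs


dim, n_experts = 16, 2
block1 = RouterBlock(dim, n_experts)
block2 = ExpertBlock(dim, n_experts)
block3 = ExpertBlock(dim, n_experts)

x = torch.randn(4, dim, requires_grad=True)
out, scores = checkpoint(block1, x, use_reentrant=False)  # see the discussion of this flag below
block2.scores = scores  # scores are set as an attribute, not passed to forward
block3.scores = scores
out = checkpoint(block2, out, use_reentrant=False)
out = checkpoint(block3, out, use_reentrant=False)
out.mean().backward()
```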

Code with different setups can be found here:

To my amazement, both setups work, but only with use_reentrant=False. However, in the transformers library the checkpoint utility is used with use_reentrant=True. When I tried it with the flag enabled, it caused the following error (I don’t get this error with the transformers library):

Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            tensors, grad_tensors_, retain_graph, create_graph, inputs,
            allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
E       RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

(I searched for it but only found more general posts, and I don’t see how they are related to this problem, e.g. RuntimeError: element 0 of variables does not require grad and does not have a grad_fn)
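
For what it’s worth, here is the smallest sketch I could come up with that seems to produce the same error; my guess is that it is related to the checkpointed input not requiring grad (the layer and shapes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8)  # input does not require grad

# with use_reentrant=True the forward inside the checkpoint runs under no_grad,
# so the output is only attached to the graph via the tensor inputs
out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)  # False

out.mean().backward()
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```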

Do you know why this is happening in this case? Also, I do not fully understand this flag. What is it doing exactly when enabled? And when should I enable it?

I have already checked various tutorials/repos/docs, but as far as I know they do not go into detail about this option (I broke the links on purpose here because I can’t add more than two per post):

  • Pytorch Docs
  • github/cybertronai/gradient-checkpointing
  • qywu.github explore-gradient-checkpointing
  • github/rasbt/deeplearning-models/blob/master/pytorch_ipynb/mechanics/gradient-checkpointing-nin.ipynb

Update:

I found this post of yours (Checkpoint with no grad requiring inputs PROBLEM - #9 by ptrblck), which explains why the backward pass breaks when none of the inputs require grad.

I updated the code accordingly: I am now setting requires_grad=True for x and keeping use_reentrant=True.
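
In code, the change is roughly this (placeholder module again):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8)

# make the checkpointed input require grad so the reentrant checkpoint output
# is attached to the autograd graph
x.requires_grad_(True)

out = checkpoint(layer, x, use_reentrant=True)
out.mean().backward()  # no error anymore
print(layer.weight.grad is not None)  # True
```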

Updated Gist:

This works for the version where the scores are explicitly passed to the block’s forward method, but not when setting the scores as an attribute. In the latter case I get the following error:

 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            tensors, grad_tensors_, retain_graph, create_graph, inputs,
            allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
E       RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
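
For comparison, here is a sketch of the variant that does work for me, using the same made-up modules as in the earlier sketch, but with requires_grad=True on the input, use_reentrant=True, and the scores passed through the checkpoint call instead of being set as an attribute:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class RouterBlock(nn.Module):
    def __init__(self, dim, n_experts):
        super().__init__()
        self.ff = nn.Linear(dim, dim)
        self.router = nn.Linear(dim, n_experts)

    def forward(self, x):
        return self.ff(x), self.router(x).softmax(dim=-1)


class ExpertBlock(nn.Module):
    def __init__(self, dim, n_experts):
        super().__init__()
        self.ff = nn.Linear(dim, dim)
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))

    def forward(self, x, scores):
        # scores now arrive as a regular forward argument
        exp_outs = torch.stack([e(x) for e in self.experts], dim=-1)
        exp_outs = (exp_outs * scores.unsqueeze(-2)).sum(dim=-1)
        return self.ff(x) + exp_outs


dim, n_experts = 16, 2
block1 = RouterBlock(dim, n_experts)
block2 = ExpertBlock(dim, n_experts)
block3 = ExpertBlock(dim, n_experts)

x = torch.randn(4, dim, requires_grad=True)
out, scores = checkpoint(block1, x, use_reentrant=True)
out = checkpoint(block2, out, scores, use_reentrant=True)
out = checkpoint(block3, out, scores, use_reentrant=True)
out.mean().backward()  # runs without the "backward through the graph a second time" error
```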