Auxiliary Loss with Gradient Checkpointing in LLMs

I tried another architecture first, which is basically the other way round: instead of the blocks setting an aux loss, they use a "scores" attribute for routing inputs to experts, and that attribute is set by a router layer in the first block.

Architecture (see setup 2; a rough code sketch follows the list):

  • Block 1:
    • x → original forward → out (block output)
    • x → router → scores (set on the other blocks as attribute)
  • Blocks 2 & 3:
    • x → original forward → out
    • x → experts → scores * experts → sum of weighted expert outputs (exp_outs)
      => out + exp_outs (block output)
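
In code, the idea looks roughly like this (a heavily simplified sketch, not my actual code; the dimensions, layer types, and names are just placeholders):

```python
import torch
import torch.nn as nn

class RouterBlock(nn.Module):
    """Block 1: runs its original forward and also computes routing scores
    that it sets as an attribute on the other blocks."""
    def __init__(self, dim, n_experts, expert_blocks):
        super().__init__()
        self.ff = nn.Linear(dim, dim)            # stand-in for the original forward
        self.router = nn.Linear(dim, n_experts)  # router layer
        self.expert_blocks = expert_blocks       # blocks 2 & 3

    def forward(self, x):
        out = self.ff(x)                                # x -> original forward -> out
        scores = torch.softmax(self.router(x), dim=-1)  # x -> router -> scores
        for blk in self.expert_blocks:
            blk.scores = scores                         # set as attribute on the other blocks
        return out

class ExpertBlock(nn.Module):
    """Blocks 2 & 3: original forward plus a sum of score-weighted expert outputs."""
    def __init__(self, dim, n_experts):
        super().__init__()
        self.ff = nn.Linear(dim, dim)            # stand-in for the original forward
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))
        self.scores = None                       # filled in by the router block

    def forward(self, x):
        out = self.ff(x)                         # x -> original forward -> out
        # x -> experts, weighted by scores and summed -> exp_outs
        exp_outs = sum(self.scores[..., i:i + 1] * expert(x)
                       for i, expert in enumerate(self.experts))
        return out + exp_outs                    # out + exp_outs (block output)
```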

Code with different setups can be found here:

To my amazement, both setups work, but only with use_reentrant=False. However, the transformers library calls the checkpoint utility with use_reentrant=True. When I enable the flag in my code, I get the following error (which I don't get with the transformers library):

Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            tensors, grad_tensors_, retain_graph, create_graph, inputs,
            allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
E       RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

(I searched for it but only found more general posts, and I don’t see how they are related to this problem, e.g. RuntimeError: element 0 of variables does not require grad and does not have a grad_fn)
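For reference, this is roughly how I wrap the blocks in the checkpoint utility (again a stripped-down sketch, not the actual call sites; the blocks are the ones described above):

```python
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(blocks, x):
    # blocks[0] is the router block, the remaining ones are expert blocks
    out = x
    for block in blocks:
        # use_reentrant=False works for me; with use_reentrant=True the
        # backward pass fails with the RuntimeError shown above
        out = checkpoint(block, out, use_reentrant=False)
    return out
```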

Do you know why this is happening in this case? Also, I don't fully understand this flag: what exactly does it do when enabled, and when should I enable it?

I have already checked various tutorials/repos/docs, but as far as I can tell they do not go into detail about this option (I broke the links on purpose here because I can't add more than two per post):

  • PyTorch docs
  • github/cybertronai/gradient-checkpointing
  • qywu.github explore-gradient-checkpointing
  • github/rasbt/deeplearning-models/blob/master/pytorch_ipynb/mechanics/gradient-checkpointing-nin.ipynb