Auxiliary Loss with Gradient Checkpointing in LLMs

Update:

I found this post of yours (Checkpoint with no grad requiring inputs PROBLEM - #9 by ptrblck), which explains why the backward pass breaks when none of the checkpointed inputs require grad.

I updated the code accordingly: I now set requires_grad=True on x and pass use_reentrant=True to checkpoint.
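To illustrate what I mean, here is a reduced sketch of that setup (a toy block with placeholder names and a made-up auxiliary score, not the actual gist code): the checkpointed input requires grad, so the reentrant checkpoint has something to attach its backward to, and the block returns auxiliary scores alongside the hidden states.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    # Placeholder block, not the original model
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        out = self.linear(x)
        aux_scores = out.pow(2).mean(dim=-1)  # stand-in for the auxiliary scores
        return out, aux_scores


block = Block()
x = torch.randn(4, 16)
x.requires_grad_(True)  # reentrant checkpoint needs at least one grad-requiring input

out, scores = checkpoint(block, x, use_reentrant=True)
loss = out.sum() + 0.1 * scores.sum()  # main loss + weighted auxiliary loss
loss.backward()
```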

Updated Gist:

This works for the version where the scores are explicitly passed to the block's forward method, but not for the version where the scores are set as an attribute. There I get the following error:

Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    tensors, grad_tensors_, retain_graph, create_graph, inputs,
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
E       RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
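For reference, here is a minimal sketch of the variant that works for me, the one where the scores are passed explicitly (again with placeholder names and a made-up score formula, not the gist code): the running scores tensor goes in and out of every checkpointed forward, so everything the auxiliary loss needs stays inside the graph that checkpoint manages, and a single backward over the combined loss goes through.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    # Placeholder block: accumulates an auxiliary score into an explicit input tensor
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, scores):
        out = self.linear(x)
        scores = scores + out.pow(2).mean()  # add this block's auxiliary score
        return out, scores


blocks = nn.ModuleList(Block() for _ in range(3))
x = torch.randn(4, 16, requires_grad=True)
scores = torch.zeros((), requires_grad=True)  # explicit grad-requiring input, not a module attribute

h = x
for block in blocks:
    h, scores = checkpoint(block, h, scores, use_reentrant=True)

loss = h.sum() + 0.1 * scores  # main loss + weighted auxiliary loss, single backward
loss.backward()
```

In the failing version, the only difference is that each block stores its scores on self instead of passing them through forward, and the auxiliary loss is built from those attributes after the checkpointed calls.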