We are looking at training transformers with both long sequences and large vocabs. When you do so, there is memory pressure specifically around the loss, since you materialize the full logits tensor of size (seqlen, vocab) (assuming batch=1). Worse, you end up with several tensors of this size: the logits themselves, the usual cast-to-fp32 copy made for the loss, and the activation grads for both of those tensors.
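To put rough numbers on it (these sizes are just an illustration I picked, not from any particular model):

# Illustrative sizes only: seqlen = 32_768, vocab = 128_000, batch = 1
seqlen, vocab = 32_768, 128_000
bf16_logits = seqlen * vocab * 2 / 2**30   # ~7.8 GiB
fp32_logits = seqlen * vocab * 4 / 2**30   # ~15.6 GiB (cast-to-fp32 copy for the loss)
fp32_grad = seqlen * vocab * 4 / 2**30     # ~15.6 GiB (grad of the fp32 logits)
bf16_grad = seqlen * vocab * 2 / 2**30     # ~7.8 GiB (grad of the bf16 logits)
# ~47 GiB of peak memory for loss-related tensors alone, before the rest of the model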
One option (discussed in pytorch/pytorch#124480, "Fused Linear and Cross-Entropy Loss `torch.nn.functional.linear_cross_entropy`") is to fuse the final linear projection and the loss function. A simpler alternative is to chunk the last hidden activations (before the final linear projection) and compute the loss per-chunk. Doing so requires breaking the autograd graph but is otherwise straightforward. In almost-pytorch pseudocode:
# h: Tensor of size (seqlen, hidden) -- final hidden activations before the last linear
# final_linear: Module for the last linear that projects hidden -> vocab
# targets: Tensor of size (seqlen,) with the true labels
# chunks: list of (start, end) pairs covering [0, seqlen), e.g. [(0, 4096), (4096, 8192), ...]

# Detach h so that each per-chunk backward stops here, accumulating into h_detached.grad
h_detached = h.detach().requires_grad_()
total_loss = torch.zeros([], device=h.device)
for start, end in chunks:
    # Only a (chunk, vocab) slice of the logits is ever materialized
    logits = final_linear(h_detached[start:end, ...]).float()
    loss = loss_fn(logits, targets[start:end])
    total_loss += loss.detach()
    loss.backward()  # accumulates grads into final_linear's params and into h_detached.grad
# Now continue backprop where you left off, seeding it with the accumulated gradient
h.backward(h_detached.grad)
# Note that we _haven't_ normalized the loss by len(chunks), which you likely want to do.
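For context, here is roughly where this sits in a training step. `backbone` and `chunked_loss_step` are made-up names: the former stands for everything up to (but not including) `final_linear`, the latter for the loop above wrapped in a function:

h = backbone(x)                                       # (seqlen, hidden), part of the autograd graph
total_loss = chunked_loss_step(h, final_linear, targets, loss_fn)
# No top-level loss.backward() here: the per-chunk backward calls plus
# h.backward(h_detached.grad) have already populated .grad on every parameter.
optimizer.step()
optimizer.zero_grad()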
This works great as far as it goes. What is tricky is that most real codebases have code that broadly looks like:
loss = loss_fn(model(x), y)
loss = loss * some_factor # or divide, or mutate in some other way
loss.backward()
where `loss` is a scalar Tensor. This relies on a single, unbroken autograd graph from the loss to all the leaves.
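Translated naively, the chunked version does not fit that pattern, since `total_loss` is detached (sketch, reusing `chunked_loss_step` from above):

total_loss = chunked_loss_step(h, final_linear, targets, loss_fn)  # detached scalar, no grad_fn
loss = total_loss * some_factor
loss.backward()  # fails: loss does not require grad and has no grad_fn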
My sense is that you could map the chunked code above onto that pattern (“`loss` is a scalar Tensor”) using a Tensor subclass. The subclass would track the extra state (the Tensor `h` to continue backprop from, and the Tensor `h_detached.grad` holding the backprop accumulated so far up to that point) and hook into backprop so that `loss.backward()` just works. But I can’t figure out from the docs where to begin / which pattern of Tensor subclassing this maps onto. Tagging @albanD for tensor subclass value. Thanks!
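To make the shape of what I'm after concrete: the closest I can sketch today splices the graph back together with a custom `torch.autograd.Function` rather than a subclass. Everything below is illustrative (the class name is made up), and note it does nothing about rescaling the grads that the per-chunk backwards already wrote into `final_linear`:

import torch

class _SpliceLoss(torch.autograd.Function):
    # Illustrative only: reattach the detached, chunk-computed loss to h's graph.
    @staticmethod
    def forward(ctx, h, total_loss, h_grad):
        # h_grad is h_detached.grad from the chunked loop; total_loss is detached
        ctx.save_for_backward(h_grad)
        return total_loss.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (h_grad,) = ctx.saved_tensors
        # grad_output carries any downstream scaling of the loss (e.g. some_factor);
        # it does NOT rescale the grads already accumulated into final_linear's params.
        return grad_output * h_grad, None, None

# Skip the final h.backward(h_detached.grad) in the loop above; loss.backward() now does that step.
loss = _SpliceLoss.apply(h, total_loss, h_detached.grad)
loss = loss * some_factor
loss.backward()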