Recording gradients for each datum in a batch

I am training a small language model with batch size 100 and context length 512. My loss is the standard CrossEntropyLoss. I am using the reduction="none" parameter to get a loss tensor of shape (100, 512) for a total of 51200 entries. I am actually interested in the gradients for all of my model parameters on each datum. For a model interpretability experiment, I really do need the datum-grain gradients rather than the batch-grain gradients. In other words, on every training step, I’d like to retrieve the gradients for all p in model.parameters() that have requires_grad=True separately for each of the 100 datums in my batch, which should be about 100x my model size to store. In particular, I do not want the averaged gradient across the batch.
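
For concreteness, here is a minimal sketch of the loss setup (model, inputs and targets are placeholders for my actual training objects; logits are assumed to have shape (batch, context, vocab)):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction="none")

logits = model(inputs)                      # (100, 512, vocab_size)
loss = criterion(
    logits.reshape(-1, logits.size(-1)),    # (51200, vocab_size)
    targets.reshape(-1),                    # (51200,)
).view(100, 512)                            # one loss entry per token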

My current approach has been to record the gradients for every token of every datum (every element of the (100, 512) loss tensor) separately and then perform the necessary arithmetic to recover the datum-level gradients. A rough example is shown in the code snippet below. This has worked fine on smaller runs (3-layer models with ~1M tokens of training, a batch size of 20 and a context length of 30). However, it is taking far too long on larger models. For my current architecture, one call to loss[k].backward(retain_graph=True) takes about 0.03 s, so a single training step would require around 51200 × 0.03 s ≈ 25 minutes just to record the datum-level gradients. I expect to run training for about 50k steps, which works out to roughly 2.3 years for the full run.

total_grads = {
    name: torch.zeros_like(param) for name, param in self.model.named_parameters()
    if name not in self.exclude_units
}
# At this point, the loss has been flattened so tokens are contiguous by batch:
# entry k corresponds to (batch = k // context_length, token = k % context_length).
# We iterate over each datum and each token to accumulate datum-level gradients.
for batch in range(data.shape[0]):
    new_grads = {
        name: torch.zeros_like(param) for name, param in self.model.named_parameters()
        if name not in self.exclude_units
    }
    for token_index in range(data.shape[1]):
        # Index into the flattened loss; the stride is the context length, data.shape[1].
        k = batch * data.shape[1] + token_index
        loss[k].backward(retain_graph=True)
        for name, param in self.model.named_parameters():
            if name not in self.exclude_units and param.grad is not None:
                new_grads[name].add_(param.grad.detach() * -learning_rate / loss.size(0))
                total_grads[name].add_(param.grad.detach())
        self.model.zero_grad()
    self.record(new_grads)

I have tried an approach along the lines of this previous question, computing gradients with respect to the unreduced loss tensor rather than the aggregated scalar from CrossEntropyLoss(reduction="mean"). However, this does not give me what I need: it does not produce gradients for the model parameters; I believe it only computes gradients at the root node of the computational graph.
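
To illustrate the failure mode (this is only a rough reconstruction; the code in the linked question may differ), differentiating the mean-reduced scalar with respect to the unreduced (100, 512) loss tensor from the sketch above just returns the constant 1/51200 at every entry and does not touch the parameter gradients:

mean_loss = loss.mean()

# Differentiate the scalar w.r.t. the unreduced loss only: this returns
# d(mean_loss)/d(loss[i, j]) = 1/51200 for every entry and leaves param.grad
# unpopulated for all model parameters.
(root_grad,) = torch.autograd.grad(mean_loss, loss, retain_graph=True)
print(root_grad.shape)          # torch.Size([100, 512])
print(root_grad[0, 0].item())   # 1 / 51200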

Is there an efficient method to compute datum-level gradients? If I could get rid of the context-length factor, that would be a 512x reduction, corresponding to a training time of 2.3 years / 512 ≈ 1.6 days, which is acceptable. I can think of two techniques:

  1. Extend Torch’s CrossEntropyLoss with reduction="mean" and a new accompanying dim parameter (here dim=1) that performs the mean reduction along a specific tensor dimension; I would reduce along dim=1, the context-length axis. I could then call loss[k].backward(retain_graph=True) with k ranging from 0 to 99 over the batch. The deprecated reduce and size_average parameters of CrossEntropyLoss come close to what I need, but not quite, since they do not apply to a specific axis; given that they are also deprecated, I want to check with the team before proceeding in this direction. (A hand-rolled sketch of this per-sequence reduction is shown after this list.)
  2. Spin up an array of 512 nodes with A100 GPUs and run a distributed job in which each device mirrors the full training process but backpropagates only the loss for the kth token (loss[batch_index, k], with k ranging from 0 to 511). I could then map-reduce (mean) the .grads from each device to obtain the datum-level gradients, and map-reduce (mean) once more to form the batch-grain gradient and take the optimizer step. Needless to say, this would be expensive (512x more GPU hours than my current setup) and I would prefer to avoid it if possible.
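
To make option 1 concrete, here is roughly what the per-sequence reduction would look like if I implemented it by hand, without touching CrossEntropyLoss itself (model, inputs and targets are placeholders as before):

import torch
import torch.nn.functional as F

B, T = inputs.shape                                   # batch size 100, context length 512

logits = model(inputs)                                # (B, T, vocab_size)
per_token_loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),              # (B*T, vocab_size)
    targets.reshape(-1),                              # (B*T,)
    reduction="none",
).view(B, T)

# Mean-reduce along dim=1 (the context axis) only: one scalar loss per datum.
per_datum_loss = per_token_loss.mean(dim=1)           # (B,)

per_datum_grads = []
for b in range(B):
    model.zero_grad()
    # retain_graph=True keeps the shared forward graph alive for the next datum.
    per_datum_loss[b].backward(retain_graph=True)
    per_datum_grads.append({
        name: param.grad.detach().clone()
        for name, param in model.named_parameters()
        if param.grad is not None
    })

This would need only 100 backward calls per step instead of 51200, i.e. the 512x reduction mentioned above, but I am not sure whether it is the right or most efficient way to do this, hence the question.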

Let me know if I am missing something and there is a more efficient way to achieve this kind of datum-level gradient capture using Torch. Thank you in advance!