I am training a small language model with batch size 100 and context length 512. My loss is the standard `CrossEntropyLoss`. I am using the `reduction="none"` parameter to get a loss tensor of shape `(100, 512)`, for a total of 51200 entries. I am *actually* interested in the gradients for all of my model parameters on each datum. For a model interpretability experiment, I really do need the datum-grain gradients rather than the batch-grain gradients. In other words, on every training step, I'd like to retrieve the gradients for all `p in model.parameters()` that have `requires_grad=True` *separately* for each of the 100 datums in my batch, which should be about 100x my model size to store. In particular, I do not want the averaged gradient across the batch.
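
For context, here is a minimal sketch of how the `(100, 512)` loss tensor is produced; the tiny linear "model" and the random tensors are placeholders for my real setup:

```
# Minimal sketch of the setup; the linear layer and random data stand in for my real model/data.
import torch
import torch.nn as nn

batch_size, context_length, vocab_size, d_model = 100, 512, 1000, 64  # illustrative sizes

model = nn.Linear(d_model, vocab_size)                        # stand-in for the real LM
inputs = torch.randn(batch_size, context_length, d_model)
targets = torch.randint(0, vocab_size, (batch_size, context_length))

criterion = nn.CrossEntropyLoss(reduction="none")
logits = model(inputs)                                        # (batch, context, vocab)
loss = criterion(logits.view(-1, vocab_size), targets.view(-1))  # (51200,), one entry per token
loss = loss.view(batch_size, context_length)                  # (100, 512)
```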

My current approach has been to record all gradients for every token and every datum (every element of the `(100, 512)` loss tensor) separately and then perform the necessary arithmetic to recover the datum-level gradients. A rough example of this is shown in the code snippet below. This has worked fine on smaller runs (3-layer models with ~1M tokens of training, a batch size of 20 and a context length of 30). However, it is taking far too long on larger models. In particular, for my current architecture I observe that one call to `loss[k].backward(retain_graph=True)` takes 0.03s, which means a single training step would require around 51200 * 0.03s ≈ 25 minutes to record datum-level gradients. I expect to run the training process for about 50k steps, so that would require about 2.3 years to run in full.
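
For reference, the 0.03s per-call figure can be reproduced with a timing snippet along these lines (the `synchronize` guard only matters when the model is on GPU):

```
import time

k = 0                                              # any single token index
if torch.cuda.is_available():
    torch.cuda.synchronize()                       # finish pending GPU work before timing
start = time.perf_counter()
loss.view(-1)[k].backward(retain_graph=True)       # backward for one token's loss entry
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"one backward: {time.perf_counter() - start:.3f}s")  # ~0.03s on my architecture
```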

```
total_grads = {
    name: torch.zeros_like(param)
    for name, param in self.model.named_parameters()
    if name not in self.exclude_units
}
# At this point, the loss has been flattened to lay data out contiguously by batch.
# We iterate over each batch element and each token to accumulate the gradients
# into datum-level gradients.
for batch in range(data.shape[0]):
    new_grads = {
        name: torch.zeros_like(param)
        for name, param in self.model.named_parameters()
        if name not in self.exclude_units
    }
    for token_index in range(data.shape[1]):
        # Batch elements are contiguous in the flattened loss, so the offset is
        # batch * context_length (data.shape[1]), not batch * data.shape[0].
        k = batch * data.shape[1] + token_index
        loss[k].backward(retain_graph=True)
        for name, param in self.model.named_parameters():
            if name not in self.exclude_units:
                new_grads[name].add_(param.grad.data * -learning_rate / loss.size(0))
                total_grads[name].add_(param.grad.data)
        # Clear the accumulated .grad before the next per-token backward call.
        self.model.zero_grad()
    self.record(new_grads)
```

I have tried using something like this previous question to compute gradients on the loss tensor (from `reduction="none"`) rather than on the aggregated scalar from `reduction="mean"`. However, this does not give me what I need: it does not produce gradients for the model parameters; I believe it only computes gradients at the root node of the computational graph.
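
Concretely, my rough reconstruction of that approach (reusing names from the first snippet above; the linked question's exact code differs) looks like this:

```
# Rough reconstruction of that approach: differentiate the mean-reduced scalar
# with respect to the unreduced loss tensor.
loss_tensor = criterion(logits.view(-1, vocab_size), targets.view(-1))  # reduction="none"
loss_scalar = loss_tensor.mean()                      # what reduction="mean" would return

(dloss,) = torch.autograd.grad(loss_scalar, loss_tensor, retain_graph=True)
# dloss is simply 1/51200 in every entry: the gradient at the root of the graph,
# not the per-datum gradients of the model parameters that I actually need.
```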

Is there an efficient method to compute datum-level gradients? If I could get rid of the context-length factor, that would be a 512x reduction, corresponding to a training time of about 2.3 years / 512 ≈ 1.6 days, which is acceptable. I can think of two techniques:

- Extend Torch's `CrossEntropyLoss` with `reduction="mean"` and a new accompanying `dim` parameter (invoked here with `1`) which would perform the mean reduction across a specific tensor dimension. In this case I would use it to reduce along `dim=1`, the context-length axis. I could then invoke `loss[k].backward(retain_graph=True)` with `k` varying from 0 to 99 according to the batch size (a sketch of what I have in mind is below this list). It seems the deprecated `reduce` and `size_average` parameters of `CrossEntropyLoss` come close to what I need, but not quite, as they don't apply to a specific axis. Given also that these are deprecated, I want to check with the team before proceeding in this direction.
- Spin up an array of 512 nodes with A100 GPU devices and do a distributed training run, with each device mirroring the full training process but computing gradients only for the kth token (`loss[batch_index, k]`), where k varies from 0 to 511. I could then perform a map-reduce mean on the `.grad`s from each device to get the datum-level gradients, and finally map-reduce mean again to apply the batch-grain gradient update and take an optimizer step. Needless to say, this approach would be pretty expensive (512x more costly in GPU hours than my current setup) and I would prefer to avoid it if at all possible.
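
To make the first option concrete, here is a minimal sketch of the behaviour I would want from such a `dim` parameter, emulated with `reduction="none"` plus a manual mean over the context axis (reusing names from the first snippet above). This still pays for 100 backward calls per step, but not 51200:

```
# Sketch of option 1: mean-reduce the per-token loss over dim=1 (the context axis),
# then run one backward per datum instead of one per token.
per_token_loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
per_datum_loss = per_token_loss.view(batch_size, context_length).mean(dim=1)  # (100,)

datum_grads = []
for b in range(batch_size):
    model.zero_grad()
    per_datum_loss[b].backward(retain_graph=True)     # 100 backward calls total
    datum_grads.append(
        {name: p.grad.detach().clone() for name, p in model.named_parameters()}
    )
```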

Let me know if I am missing something and there is a more efficient way to achieve this kind of datum-level gradient capture using Torch. Thank you in advance!