Errors in backward when using FSDP -- RuntimeError: setStorage: sizes [32000, 4096], strides [4096, 1], storage offset 131076096, and itemsize 2 requiring a storage size of 524296192 are out of bounds for storage of size 0

Hi. I am fine-tuning the llama-7b model on 4 A100 GPUs with FSDP and LoRA (peft), based on the llama-recipes library (GitHub - facebookresearch/llama-recipes: Examples and recipes for Llama 2 model).

Torch version: 2.2.0+cu118

My goal is to compute the “per-token gradient” of each parameter for analysis purposes.

Here’s how I’ve modified the training code:

  1. I modified the model code to return the unreduced (per-token) loss by setting reduction='none' in CrossEntropyLoss.

  2. Instead of calling backward on the averaged loss, I now call backward on the loss of each individual token.
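For reference, the two steps above can be sketched in plain PyTorch on a tiny, hypothetical stand-in model (no FSDP, no LoRA; `vocab`, `hidden`, `seq_len` and the `nn.Sequential` model are illustrative assumptions, not the llama-recipes code). Without FSDP, repeated backward passes over one graph work as long as `retain_graph=True` is passed to every call except the last.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy "language model": embedding followed by an output projection.
vocab, hidden, seq_len = 10, 8, 4
model = nn.Sequential(nn.Embedding(vocab, hidden), nn.Linear(hidden, vocab))
params = [p for p in model.parameters() if p.requires_grad]

inputs = torch.randint(0, vocab, (seq_len,))
targets = torch.randint(0, vocab, (seq_len,))

logits = model(inputs)                                          # (seq_len, vocab)
loss = nn.CrossEntropyLoss(reduction="none")(logits, targets)   # (seq_len,), one loss per token

per_token_grads = []
for t in range(seq_len):
    # Keep the graph alive for every token except the last one.
    grads = torch.autograd.grad(loss[t], params, retain_graph=t < seq_len - 1)
    per_token_grads.append(grads)
```

As a sanity check, summing the per-token gradients should reproduce the gradient of `loss.sum()` from a fresh forward pass.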

In the code below, I changed this section:

to:

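loss[0].backward(retain_graph=True)  # keep the graph alive for the second backward
optimizer.zero_grad()                # clear the gradients accumulated from token 0
loss[1].backward()                   # gradients of token 1 only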

However, this modification results in an error:

Traceback (most recent call last):
  File ".../lib/python3.8/site-packages/torch/autograd/__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: setStorage: sizes [32000, 4096], strides [4096, 1], storage offset 131076096, and itemsize 2 requiring a storage size of 524296192 are out of bounds for storage of size 0
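The numbers in the message are at least self-consistent. [32000, 4096] matches the llama-7b vocabulary size × hidden size, so this is presumably the embedding or output-projection weight, and itemsize 2 corresponds to a 16-bit dtype such as fp16 or bf16. The small check below, using only values copied from the traceback, reproduces the required byte count; the point is that the view is valid in itself, but the storage behind it reports size 0, i.e. that memory has been released by the time the second backward needs it. (FSDP frees unsharded parameter storage after use, which may be related, but that is an interpretation, not something the traceback states.)

```python
# Values copied from the error message.
sizes = (32000, 4096)
strides = (4096, 1)
storage_offset = 131076096   # in elements
itemsize = 2                 # bytes per element (16-bit dtype)

# Index of the last element the view touches, counted from the start of the storage.
last_index = storage_offset + sum((n - 1) * s for n, s in zip(sizes, strides))
required_bytes = (last_index + 1) * itemsize
print(required_bytes)  # 524296192
```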

I also tried torch.autograd.grad as follows:

grad1 = torch.autograd.grad(loss[0], updated_parameters, retain_graph=True, allow_unused=True)
optimizer.zero_grad()  # not strictly needed: torch.autograd.grad returns gradients instead of accumulating into .grad
grad2 = torch.autograd.grad(loss[1], updated_parameters, allow_unused=True)

but it led to the same error.

Since I’m using both FSDP and LoRA, this seems to be a complex issue.
One possible workaround is to run a separate forward pass for each token instead of a single pass over the entire sequence, but this could be very time-consuming.
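The per-token-forward workaround can be sketched as follows, again on a hypothetical toy model rather than the FSDP-wrapped llama (all names here are illustrative assumptions). Each token gets its own forward pass and its own graph, so no backward ever runs twice over the same graph and `retain_graph` is never needed. This matches the one-forward-one-backward pattern FSDP normally expects, but it has not been tested under FSDP here, and the cost grows linearly with sequence length, which is exactly the slowdown mentioned above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model standing in for the real one.
vocab, hidden, seq_len = 10, 8, 4
model = nn.Sequential(nn.Embedding(vocab, hidden), nn.Linear(hidden, vocab))
params = [p for p in model.parameters() if p.requires_grad]
loss_fn = nn.CrossEntropyLoss(reduction="none")

inputs = torch.randint(0, vocab, (seq_len,))
targets = torch.randint(0, vocab, (seq_len,))

per_token_grads = []
for t in range(seq_len):
    # Fresh forward pass for every token: one graph per backward.
    loss = loss_fn(model(inputs), targets)
    grads = torch.autograd.grad(loss[t], params)
    per_token_grads.append([g.detach().clone() for g in grads])
```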

Additionally, I’m curious about the cause of this problem.
Does anyone have any insights?


Same problem. How did you fix it? Thanks.