Problem when finetuning LLM using opacus

I tried to finetune an LLM (distilgpt2 from huggingface.co) using Opacus. I did everything as usual but got an unexpected error.
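For context, the training step is set up roughly like this (a minimal sketch, not my exact code from great.py / great_trainer.py; the dummy data, batch size, and privacy parameters are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("distilgpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# dummy token ids standing in for my real tokenized dataset (placeholder)
input_ids = torch.randint(0, model.config.vocab_size, (64, 32))
data_loader = DataLoader(TensorDataset(input_ids), batch_size=16)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

model.train()
for (batch,) in data_loader:
    optimizer.zero_grad()
    out = model(input_ids=batch, labels=batch)  # causal LM loss
    out.loss.backward()
    optimizer.step()  # the error below is raised here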

Traceback (most recent call last):
  File "/u/nkp2mr/kaic/dp_diffusion_synthesis/LLM/run_llm.py", line 25, in <module>
    model.fit(data_path)
  File "/u/nkp2mr/kaic/dp_diffusion_synthesis/./LLM/great.py", line 192, in fit
    self.model = model_trainer.train(great_ds, 'finetune')
  File "/u/nkp2mr/kaic/dp_diffusion_synthesis/./LLM/great_trainer.py", line 106, in train
    self.optimizer.step()
  File "/u/nkp2mr/.local/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 130, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
  File "/u/nkp2mr/.local/lib/python3.10/site-packages/opacus/optimizers/optimizer.py", line 513, in step
    if self.pre_step():
  File "/u/nkp2mr/.local/lib/python3.10/site-packages/opacus/optimizers/optimizer.py", line 494, in pre_step
    self.clip_and_accumulate()
  File "/u/nkp2mr/.local/lib/python3.10/site-packages/opacus/optimizers/optimizer.py", line 404, in clip_and_accumulate
    per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
RuntimeError: stack expects each tensor to be equal size, but got [13] at entry 0 and [1] at entry 1
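As far as I understand, clip_and_accumulate computes one per-sample gradient norm tensor per parameter (each of shape [batch_size]) and then stacks them; the failure happens because one parameter reports a different "batch size" than the others. A tiny repro of just that stack call, with made-up sizes matching the error message:

import torch

norms_ok = torch.randn(13)   # shape [13] -> one norm per sample in the batch
norms_bad = torch.randn(1)   # shape [1]  -> only a single "sample"

torch.stack([norms_ok, norms_bad], dim=1)
# RuntimeError: stack expects each tensor to be equal size, but got [13] at entry 0 and [1] at entry 1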

I printed the model structure:
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

I also printed the shape of grad_sample for each parameter, and it looks like the positional embedding layer (wpe, parameter shape [1024, 768]) is the odd one out: its grad_sample has batch dimension 1, while every other parameter has 13.
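This is roughly the loop I used to collect the shapes (a sketch; my actual print statement may be formatted slightly differently):

for p in model.parameters():
    # Opacus attaches a per-sample gradient tensor to each trainable parameter;
    # its first dimension should equal the batch size
    gs = getattr(p, "grad_sample", None)
    if p.requires_grad and gs is not None:
        print(f"params shape: {p.shape}, params grad_sample shape: {gs.shape}")

It printed: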
params shape: torch.Size([50257, 768]), params grad_sample shape: torch.Size([13, 50257, 768])
params shape: torch.Size([1024, 768]), params grad_sample shape: torch.Size([1, 1024, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 2304]), params grad_sample shape: torch.Size([13, 768, 2304])
params shape: torch.Size([2304]), params grad_sample shape: torch.Size([13, 2304])
params shape: torch.Size([768, 768]), params grad_sample shape: torch.Size([13, 768, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 3072]), params grad_sample shape: torch.Size([13, 768, 3072])
params shape: torch.Size([3072]), params grad_sample shape: torch.Size([13, 3072])
params shape: torch.Size([3072, 768]), params grad_sample shape: torch.Size([13, 3072, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 2304]), params grad_sample shape: torch.Size([13, 768, 2304])
params shape: torch.Size([2304]), params grad_sample shape: torch.Size([13, 2304])
params shape: torch.Size([768, 768]), params grad_sample shape: torch.Size([13, 768, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 3072]), params grad_sample shape: torch.Size([13, 768, 3072])
params shape: torch.Size([3072]), params grad_sample shape: torch.Size([13, 3072])
params shape: torch.Size([3072, 768]), params grad_sample shape: torch.Size([13, 3072, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 2304]), params grad_sample shape: torch.Size([13, 768, 2304])
params shape: torch.Size([2304]), params grad_sample shape: torch.Size([13, 2304])
params shape: torch.Size([768, 768]), params grad_sample shape: torch.Size([13, 768, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 3072]), params grad_sample shape: torch.Size([13, 768, 3072])
params shape: torch.Size([3072]), params grad_sample shape: torch.Size([13, 3072])
params shape: torch.Size([3072, 768]), params grad_sample shape: torch.Size([13, 3072, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 2304]), params grad_sample shape: torch.Size([13, 768, 2304])
params shape: torch.Size([2304]), params grad_sample shape: torch.Size([13, 2304])
params shape: torch.Size([768, 768]), params grad_sample shape: torch.Size([13, 768, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 3072]), params grad_sample shape: torch.Size([13, 768, 3072])
params shape: torch.Size([3072]), params grad_sample shape: torch.Size([13, 3072])
params shape: torch.Size([3072, 768]), params grad_sample shape: torch.Size([13, 3072, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 2304]), params grad_sample shape: torch.Size([13, 768, 2304])
params shape: torch.Size([2304]), params grad_sample shape: torch.Size([13, 2304])
params shape: torch.Size([768, 768]), params grad_sample shape: torch.Size([13, 768, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 3072]), params grad_sample shape: torch.Size([13, 768, 3072])
params shape: torch.Size([3072]), params grad_sample shape: torch.Size([13, 3072])
params shape: torch.Size([3072, 768]), params grad_sample shape: torch.Size([13, 3072, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 2304]), params grad_sample shape: torch.Size([13, 768, 2304])
params shape: torch.Size([2304]), params grad_sample shape: torch.Size([13, 2304])
params shape: torch.Size([768, 768]), params grad_sample shape: torch.Size([13, 768, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768, 3072]), params grad_sample shape: torch.Size([13, 768, 3072])
params shape: torch.Size([3072]), params grad_sample shape: torch.Size([13, 3072])
params shape: torch.Size([3072, 768]), params grad_sample shape: torch.Size([13, 3072, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])
params shape: torch.Size([768]), params grad_sample shape: torch.Size([13, 768])

How can I solve this problem? Any help would be greatly appreciated.