Issue with FSDP and autograd when computing total loss with embeddings

I’m working on a project that uses FSDP (Fully Sharded Data Parallel) for distributed training in PyTorch. The total loss is the sum of a standard BCE loss and a gradient penalty computed with torch.autograd.grad, where the gradient of the output logits is taken with respect to the input embeddings. The backward pass fails, and I’ve traced the failure to the interaction between FSDP and autograd.

Here’s the relevant code snippet:

python code

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from accelerate import Accelerator

accelerator = Accelerator()

model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased').to(accelerator.device)
model = accelerator.prepare(model)

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
optimizer = accelerator.prepare(optimizer)

criterion = nn.BCEWithLogitsLoss()

# Dummy inputs
input_text = ["This is a test sentence.", "Another test sentence."]
input_ids = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt')['input_ids'].to(accelerator.device)
# Mask out padding positions (True where the token is not padding)
attention_mask = input_ids != tokenizer.pad_token_id

# Forward pass; request hidden states so the embedding output (hidden_states[0]) is available
output = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
output_logits = output.logits[:, 1]

# Compute GAIL loss
loss = criterion(output_logits, torch.ones_like(output_logits))

output_embeds = output.hidden_states[0]

# Compute gradient penalty (grad_pen)
alpha = torch.rand(output_embeds.size(0), 1, 1).expand(output_embeds.size()).to(accelerator.device)
mixup_data_embeds = alpha * output_embeds

# Compute discriminator output by feeding the scaled embeddings back into the same model
disc_mixup_output = model(inputs_embeds=mixup_data_embeds)
disc_mixup_output_logits = disc_mixup_output.logits

# Compute gradient of disc_mixup with respect to mixup_data_embeds
ones = torch.ones_like(disc_mixup_output_logits)
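# NOTE: the arguments of this call were cut off in the post; the usual
# gradient-penalty form is assumed here.
grad = torch.autograd.grad(
    outputs=disc_mixup_output_logits,
    inputs=mixup_data_embeds,
    grad_outputs=ones,
    create_graph=True,  # keep the graph so grad_pen can be backpropagated
    retain_graph=True,
)[0]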

# Compute the norm-based gradient penalty
lambda_ = 0.1
grad_pen = ((grad.norm(2, dim=1) - 1) ** 2).mean() * lambda_

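total_loss = loss + grad_pen

# Backward pass (this is the call that fails in the traceback below)
accelerator.backward(total_loss)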


However, when I run this code, I get the following error:

bash output

Traceback (most recent call last):
  File "", line 60, in <module>
  File "/home/user/.local/lib/python3.8/site-packages/accelerate/", line 1853, in backward
  File "/home/user/.local/lib/python3.8/site-packages/torch/", line 487, in backward
  File "/home/user/.local/lib/python3.8/site-packages/torch/autograd/", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/fsdp/", line 643, in _pre_backward_hook
    _prefetch_handles(state, _handles_key)
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/fsdp/", line 1003, in _prefetch_handles
    handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/fsdp/", line 1046, in _get_handles_to_prefetch
    target_handles_keys = [
TypeError: 'NoneType' object is not iterable
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2715808) of binary: /share/apps/anaconda3/2021.05/bin/python
Traceback (most recent call last):
  File "/home/user/.local/bin/accelerate", line 8, in <module>
  File "/home/user/.local/lib/python3.8/site-packages/accelerate/commands/", line 45, in main
  File "/home/user/.local/lib/python3.8/site-packages/accelerate/commands/", line 966, in launch_command
  File "/home/user/.local/lib/python3.8/site-packages/accelerate/commands/", line 646, in multi_gpu_launcher
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/", line 785, in run
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/launcher/", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/launcher/", line 250, in launch_agent
    raise ChildFailedError(

accelerate config

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_min_num_params: 1000
  fsdp_offload_params: false
  fsdp_sharding_strategy: 2
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

I have narrowed the issue down to the torch.autograd.grad call: the error only appears when FSDP is enabled in the accelerate config.
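
For reference, the gradient-penalty step can be isolated in a single process, with no FSDP and no pretrained model. The sketch below is not the original script: `TinyDiscriminator` and `gradient_penalty` are hypothetical stand-ins, and the scaled embeddings are detached and marked as requiring grad because the toy input does not come out of a model. It only illustrates the double-backward pattern (torch.autograd.grad with create_graph=True, followed by a normal backward) that the FSDP run is expected to support.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyDiscriminator(nn.Module):
    """Hypothetical stand-in for the classifier: embeddings in, one logit per example out."""

    def __init__(self, hidden=8):
        super().__init__()
        self.proj = nn.Linear(hidden, 1)

    def forward(self, inputs_embeds):
        # Mean-pool over the sequence dimension, then project to a single logit.
        return self.proj(torch.tanh(inputs_embeds).mean(dim=1)).squeeze(-1)


def gradient_penalty(model, embeds, lambda_=0.1):
    # Scale each example's embeddings by a random factor, as in the post.
    alpha = torch.rand(embeds.size(0), 1, 1, device=embeds.device)
    # Assumption: treat the scaled embeddings as a fresh leaf that requires grad.
    mixed = (alpha * embeds).detach().requires_grad_(True)
    logits = model(inputs_embeds=mixed)
    # create_graph=True keeps the graph so the penalty itself can be backpropagated.
    (grad,) = torch.autograd.grad(
        outputs=logits,
        inputs=mixed,
        grad_outputs=torch.ones_like(logits),
        create_graph=True,
    )
    # Norm over all non-batch dimensions (the post uses dim=1 on the 3-D tensor instead).
    return ((grad.flatten(1).norm(2, dim=1) - 1) ** 2).mean() * lambda_


model = TinyDiscriminator()
embeds = torch.randn(2, 5, 8)  # (batch, sequence, hidden)
logits = model(inputs_embeds=embeds)
bce = nn.BCEWithLogitsLoss()(logits, torch.ones_like(logits))
grad_pen = gradient_penalty(model, embeds)
total_loss = bce + grad_pen
total_loss.backward()  # second-order backward through the penalty term
```

If a snippet like this runs on one device, the autograd usage itself is sound and the problem is specific to how FSDP's backward hooks handle a second pass through the wrapped model.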

I suspect that the issue may be related to the way I’m using autograd for this specific loss with embeddings as input and logits as output, but I’m not entirely sure. Has anyone encountered a similar problem or can offer insights into what I might be doing wrong?

Additional Information:

  • PyTorch version: 2.0.1+cu117
  • GPUs: 2x A5000

Any help or guidance on this issue would be greatly appreciated. Thank you!

I would recommend posting this as an issue in the huggingface/accelerate repository on GitHub.