I’m currently working on a project that uses FSDP (Fully Sharded Data Parallel) for distributed training in PyTorch. I’m trying to compute a total loss that combines a standard BCE loss with an additional gradient-penalty loss obtained via autograd, using the embeddings as input and the logits as output. I’ve run into an issue and traced it to FSDP not playing nicely with autograd.
Here’s the relevant code snippet:
python code
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from accelerate import Accelerator
accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased').to(accelerator.device)
model = accelerator.prepare(model)
accelerator.unwrap_model(model).train()
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
optimizer = accelerator.prepare(optimizer)
criterion = nn.BCEWithLogitsLoss()
# Dummy inputs
input_text = ["This is a test sentence.", "Another test sentence."]
input_ids = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt')['input_ids'].to(accelerator.device)
attention_mask = (input_ids != tokenizer.pad_token_id).to(accelerator.device)
# Forward pass to get logits and hidden states
output = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
output_logits = output.logits[:, 1]
# Compute GAIL loss
loss = criterion(output_logits, torch.ones_like(output_logits))
output_embeds = output.hidden_states[0]  # embedding-layer output (first entry of hidden_states)
# Compute gradient penalty (grad_pen)
alpha = torch.rand(output_embeds.size(0), 1, 1).expand(output_embeds.size()).to(accelerator.device)
mixup_data_embeds = alpha * output_embeds
# Run the mixed embeddings back through the discriminator (the sequence classification model)
disc_mixup_output = model(inputs_embeds=mixup_data_embeds)
disc_mixup_output_logits = disc_mixup_output.logits
# Compute gradient of disc_mixup with respect to mixup_data_embeds
ones = torch.ones_like(disc_mixup_output_logits)
grad = torch.autograd.grad(
    outputs=disc_mixup_output_logits,
    inputs=mixup_data_embeds,
    grad_outputs=ones,
    create_graph=True,
)[0]
# Compute the norm-based gradient penalty
lambda_ = 0.1
grad_pen = ((grad.norm(2, dim=1) - 1) ** 2).mean() * lambda_
total_loss = (loss + grad_pen)
optimizer.zero_grad()
accelerator.backward(total_loss)
optimizer.step()
However, when I run this code, I get the following error:
bash output
Traceback (most recent call last):
File "debug_autograd_bert.py", line 60, in <module>
accelerator.backward(total_loss)
File "/home/user/.local/lib/python3.8/site-packages/accelerate/accelerator.py", line 1853, in backward
loss.backward(**kwargs)
File "/home/user/.local/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/home/user/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 643, in _pre_backward_hook
_prefetch_handles(state, _handles_key)
File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1003, in _prefetch_handles
handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1046, in _get_handles_to_prefetch
target_handles_keys = [
TypeError: 'NoneType' object is not iterable
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2715808) of binary: /share/apps/anaconda3/2021.05/bin/python
Traceback (most recent call last):
File "/home/user/.local/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/home/user/.local/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
args.func(args)
File "/home/user/.local/lib/python3.8/site-packages/accelerate/commands/launch.py", line 966, in launch_command
multi_gpu_launcher(args)
File "/home/user/.local/lib/python3.8/site-packages/accelerate/commands/launch.py", line 646, in multi_gpu_launcher
distrib_run.run(args)
File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
elastic_launch(
File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
accelerate config
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: SIZE_BASED_WRAP
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_min_num_params: 1000
fsdp_offload_params: false
fsdp_sharding_strategy: 2
fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
I have narrowed the issue down to the torch.autograd.grad call: the error only appears once FSDP is enabled in the accelerate config.
I suspect the issue may be related to the way I’m using autograd for this specific loss, with embeddings as input and logits as output, but I’m not entirely sure. Has anyone encountered a similar problem, or can anyone offer insight into what I might be doing wrong?
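To show what I mean by "narrowed down", here is a minimal sketch of the same double-backward pattern without FSDP. The TinyClassifier below is just a made-up stand-in for the real model, not my actual code, but the torch.autograd.grad call with create_graph=True is identical, and this version runs without error on a single, unsharded model:
python code
import torch
import torch.nn as nn

# Hypothetical stand-in for the sequence classification model, for illustration only
class TinyClassifier(nn.Module):
    def __init__(self, hidden=16, num_labels=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_labels)
        )

    def forward(self, inputs_embeds):
        # Mean-pool over the sequence dimension, then classify
        return self.net(inputs_embeds.mean(dim=1))

model = TinyClassifier()
criterion = nn.BCEWithLogitsLoss()

embeds = torch.randn(2, 8, 16, requires_grad=True)  # (batch, seq_len, hidden)
logits = model(inputs_embeds=embeds)[:, 1]
loss = criterion(logits, torch.ones_like(logits))

# Same gradient-penalty pattern as in the FSDP script above
alpha = torch.rand(embeds.size(0), 1, 1).expand(embeds.size())
mixup_embeds = alpha * embeds
mixup_logits = model(inputs_embeds=mixup_embeds)[:, 1]
grad = torch.autograd.grad(
    outputs=mixup_logits,
    inputs=mixup_embeds,
    grad_outputs=torch.ones_like(mixup_logits),
    create_graph=True,
)[0]
grad_pen = ((grad.norm(2, dim=1) - 1) ** 2).mean() * 0.1

(loss + grad_pen).backward()  # completes without error when no FSDP is involved
The only differences from the failing script are the stand-in model and the absence of FSDP wrapping.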
Additional Information:
- PyTorch version: 2.0.1+cu117
- GPUs: 2x A5000
Any help or guidance on this issue would be greatly appreciated. Thank you!