In-place operation error when running backward with DDP

I only get the error below when I use DDP to train on 1 or more GPUs.
I am using DDP through Hugging Face Accelerate.

2023-08-22 12:46:17.628072: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2023-08-22 12:46:17.694424: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-22 12:46:19.781103: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2023-08-22 12:46:38.580436: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.weight']

  • This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
  • This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/home/am2448/.local/lib/python3.8/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in EmbeddingBackward0. Traceback of forward call that caused the error:
  File "debug_autograd_bert.py", line 30, in <module>
    output = model(input_ids=chosen_input_ids, attention_mask=chosen_attention_mask, output_hidden_states=True)
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/am2448/.local/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 751, in forward
    distilbert_output = self.distilbert(
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/am2448/.local/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 570, in forward
    inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/am2448/.local/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 134, in forward
    position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 (Triggered internally at …/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "debug_autograd_bert.py", line 63, in <module>
    accelerator.backward(total_loss)
  File "/home/am2448/.local/lib/python3.8/site-packages/accelerate/accelerator.py", line 1853, in backward
    loss.backward(**kwargs)
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/am2448/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 8]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1168116) of binary: /share/apps/anaconda3/2021.05/bin/python
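If I understand the message correctly, autograd saved a tensor during the forward pass, that tensor was modified in place before backward ran (its version counter went from 2 to 3), so the saved value is considered stale. A tiny standalone example, unrelated to my script, that I believe produces the same class of error:

import torch

x = torch.ones(3, requires_grad=True)
y = x * 2           # intermediate tensor; sin() saves it for its backward
z = y.sin()         # backward of sin() needs the original y (cos(y))
y.add_(1.0)         # in-place modification bumps y's version counter
z.sum().backward()  # RuntimeError: ... modified by an inplace operation

In my case the flagged tensor is a torch.cuda.LongTensor of shape [1, 8], which from the traceback looks like it could be the position_ids fed to the position embeddings, but I never modify those myself.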

Here’s my code:

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# import huggingface accelerate
from accelerate import Accelerator

torch.autograd.set_detect_anomaly(True)

accelerator = Accelerator()
# Initialize the DistilBERT sequence classification model
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased').to(accelerator.device)
model = accelerator.prepare(model)
accelerator.unwrap_model(model).train()

# Initialize the DistilBERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
optimizer = accelerator.prepare(optimizer)

criterion = nn.BCEWithLogitsLoss()

# Define a simple input
chosen_input_text = ["This is a test sentence.", "Another test sentence."]
chosen_input_ids = tokenizer(chosen_input_text, padding=True, truncation=True, return_tensors='pt')['input_ids'].to(accelerator.device)
chosen_attention_mask = chosen_input_ids != tokenizer.pad_token_id
chosen_attention_mask = chosen_attention_mask.to(accelerator.device)

# Generate synthetic logits for "chosen" and "rejected" examples
output = model(input_ids=chosen_input_ids, attention_mask=chosen_attention_mask, output_hidden_states=True)
output_logits = output.logits[:, 1]

# Compute GAIL loss
loss = criterion(output_logits, torch.ones_like(output_logits))

output_embeds = output.hidden_states[0].clone()

# Compute gradient penalty (grad_pen)
lambda_ = 0.1
alpha = torch.rand(output_embeds.size(0), 1, 1).expand(output_embeds.size()).to(accelerator.device)
# print(output_embeds)
mixup_data_embeds = alpha * output_embeds
# print(mixup_data_embeds)
# Compute discriminator output by feeding the mixed embeddings back through the model
disc_mixup_output = model(inputs_embeds=mixup_data_embeds)
disc_mixup_output_logits = disc_mixup_output.logits

# Compute gradient of disc_mixup with respect to mixup_data_embeds
ones = torch.ones_like(disc_mixup_output_logits)
grad = torch.autograd.grad(
    outputs=disc_mixup_output_logits,
    inputs=mixup_data_embeds,
    grad_outputs=ones,
    create_graph=True,
)[0]

# Compute the norm-based gradient penalty
grad_pen = ((grad.norm(2, dim=1) - 1) ** 2).mean() * lambda_

total_loss = (loss + grad_pen)

optimizer.zero_grad()
accelerator.backward(total_loss)
optimizer.step()

My accelerate config looks like this:

compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

and the command I use to run the script from my terminal is:

accelerate launch --config_file accelerate_config.yaml \
 --main_process_port 29631 test.py
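
For comparison, when I run the same script as a plain single process (so the model is not wrapped in DDP), for example:

python test.py

the backward pass completes without this error, which is why I suspect the DDP wrapping rather than the model code itself.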

Any help would be very much appreciated!