RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 54]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient.

I am trying to fine-tune CLIP by giving it two images and a text as inputs. I keep getting the error above, and I have looked at other posts on the forum, but none of them seem to help. Here is the stack trace:

/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in EmbeddingBackward0. Traceback of forward call that caused the error:
  File "/workspace/LLaVA/finetune_CLIP_DDP_2I.py", line 256, in <module>
    main()
  File "/workspace/LLaVA/finetune_CLIP_DDP_2I.py", line 253, in main
    train(local_rank, world_size, subset_data, eval_data) #Currently for subset
  File "/workspace/LLaVA/finetune_CLIP_DDP_2I.py", line 86, in train
    outputs1 = model(input_ids=input_ids, pixel_values=pixel_values1)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 1108, in forward
    text_outputs = self.text_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 691, in forward
    hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 218, in forward
    position_embeddings = self.position_embedding(position_ids)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 (Triggered internally at …/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/workspace/LLaVA/finetune_CLIP_DDP_2I.py", line 256, in <module>
    main()
  File "/workspace/LLaVA/finetune_CLIP_DDP_2I.py", line 253, in main
    train(local_rank, world_size, subset_data, eval_data) #Currently for subset
  File "/workspace/LLaVA/finetune_CLIP_DDP_2I.py", line 110, in train
    loss.backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 54]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Based on the stack trace, the error seems to originate in the first forward pass (outputs1).

Here is my code:

class CustomDataset(Dataset):
    def __init__(self, data, processor):
        self.data = data
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        print(self.data[idx])
        image_paths, text = self.data[idx]
        image_path = "../LLM_annots/dataset/images_eating_full/images/" + image_paths[0]
        image_path2 = "../LLM_annots/dataset/images_eating_full/images/" + image_paths[1]
        image = Image.open(image_path).convert("RGB")
        image2 = Image.open(image_path2).convert("RGB")
        inputs = self.processor(text=[text], images=[image, image2], return_tensors="pt", padding=True)
        input_ids = inputs['input_ids'].squeeze(0)
        pixel_values1 = inputs['pixel_values'][0]
        pixel_values2 = inputs['pixel_values'][1]
        
        return input_ids, pixel_values1, pixel_values2

def setup(rank, world_size):
    dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()

def collate_fn(batch):
    input_ids, pixel_values1, pixel_values2 = zip(*batch)
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
    pixel_values1_stacked = torch.stack(pixel_values1)
    pixel_values2_stacked = torch.stack(pixel_values2)
    
    return input_ids_padded, pixel_values1_stacked, pixel_values2_stacked

def train(rank, world_size, data, eval_data):
    setup(rank, world_size)
    torch.autograd.set_detect_anomaly(True)
    
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336")
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])
    
    dataset = CustomDataset(data, processor)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=8, sampler=sampler, collate_fn=collate_fn)  # Adjusted batch size

    eval_dataset = CustomDataset(eval_data, processor)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
    eval_dataloader = DataLoader(eval_dataset, batch_size=8, sampler=eval_sampler, collate_fn=collate_fn)  # Adjusted batch size
    
    optimizer = optim.AdamW(model.parameters(), lr=1e-5)  # Adjusted learning rate
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    # acc_steps = 4
    for epoch in range(10):
        for i, (input_ids, pixel_values1, pixel_values2) in enumerate(dataloader):
            input_ids = input_ids.to(rank)
            pixel_values1 = pixel_values1.to(rank)
            pixel_values2 = pixel_values2.to(rank)

            # Check inputs
            if torch.isnan(input_ids).any() or torch.isinf(input_ids).any():
                print(f"NaN or Inf detected in input_ids at batch {i}")
            
            outputs1 = model(input_ids=input_ids, pixel_values=pixel_values1)
            outputs2 = model(input_ids=input_ids, pixel_values=pixel_values2)

            # Check outputs
            if torch.isnan(outputs1.logits_per_image).any() or torch.isinf(outputs1.logits_per_image).any():
                print(f"NaN or Inf detected in outputs1 at batch {i}")
            
            logits_per_image1 = outputs1.logits_per_image
            logits_per_text1 = outputs1.logits_per_text
            logits_per_image2 = outputs2.logits_per_image
            logits_per_text2 = outputs2.logits_per_text
            
            logits_per_image = (logits_per_image1.clone() + logits_per_image2.clone()) / 2
            logits_per_text = (logits_per_text1.clone() + logits_per_text2.clone()) / 2
            
            ground_truth = torch.arange(len(logits_per_image), device=rank)
            
            loss = (loss_fn(logits_per_image.clone(), ground_truth) + loss_fn(logits_per_text.clone(), ground_truth)) / 2
            loss.backward()

            with torch.no_grad():
                optimizer.zero_grad()
                optimizer.step()

            if rank == 0:
                print(f"Epoch [{epoch+1}/10], Loss: {loss.item()}")
                
        if rank == 0:
            evaluate(model, processor, eval_dataloader, rank)

    cleanup()


def evaluate(model, processor, dataloader, rank):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for input_ids, pixel_values1, pixel_values2 in dataloader:
            input_ids = input_ids.to(rank)
            pixel_values1 = pixel_values1.to(rank)
            pixel_values2 = pixel_values2.to(rank)
            
            outputs1 = model(input_ids=input_ids, pixel_values=pixel_values1)
            outputs2 = model(input_ids=input_ids, pixel_values=pixel_values2)
            
            logits_per_image1 = outputs1.logits_per_image
            logits_per_text1 = outputs1.logits_per_text
            logits_per_image2 = outputs2.logits_per_image
            logits_per_text2 = outputs2.logits_per_text
            
            logits_per_image = (logits_per_image1.clone() + logits_per_image2.clone()) / 2
            logits_per_text = (logits_per_text1.clone() + logits_per_text2.clone()) / 2
            
            ground_truth = torch.arange(len(logits_per_image), device=rank)
            _, predicted = torch.max(logits_per_image, 1)
            correct += (predicted == ground_truth).sum().item()
            total += ground_truth.size(0)
    
    accuracy = 100 * correct / total
    print(f'Evaluation Accuracy: {accuracy:.2f}%')

def main():
    json_file = '../LLM_annots/dataset/images_eating_full/eating_labels_bal_ne.json'
    list_data = create_image_caption_tuples(json_file)
    with open(json_file, 'r') as file:
        data = json.load(file)
    print(len(list_data))
    subset_data, eval_data = sample_items(list_data)
    
    
    world_size = int(os.getenv('WORLD_SIZE', '1'))  # Set world size from environment
    rank = int(os.getenv('RANK', '0'))  # Set rank from environment
    local_rank = int(os.getenv('LOCAL_RANK', '0'))  # Set local rank from environment
    train(local_rank, world_size, subset_data, eval_data) #Currently for subset

if __name__ == "__main__":
    main()

Try moving the optimizer.zero_grad() to the start of the inner loop, like this:

for input_ids, pixel_values1, pixel_values2 in dataloader:
    optimizer.zero_grad()
    input_ids = input_ids.to(rank)

and have optimizer.step() just after loss.backward().

Also drop the with torch.no_grad(): block around the optimizer calls and the clone() calls; they are not really required.
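Putting that together, the inner loop would look roughly like this (a sketch based on your posted code and variable names, untested):

for input_ids, pixel_values1, pixel_values2 in dataloader:
    optimizer.zero_grad()    # clear grads before the forward/backward pass
    input_ids = input_ids.to(rank)
    pixel_values1 = pixel_values1.to(rank)
    pixel_values2 = pixel_values2.to(rank)

    outputs1 = model(input_ids=input_ids, pixel_values=pixel_values1)
    outputs2 = model(input_ids=input_ids, pixel_values=pixel_values2)

    # Averaging already creates new tensors, so no clone() is needed.
    logits_per_image = (outputs1.logits_per_image + outputs2.logits_per_image) / 2
    logits_per_text = (outputs1.logits_per_text + outputs2.logits_per_text) / 2

    ground_truth = torch.arange(len(logits_per_image), device=rank)
    loss = (loss_fn(logits_per_image, ground_truth) + loss_fn(logits_per_text, ground_truth)) / 2

    loss.backward()
    optimizer.step()    # step after backward, while the grads are still populated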


Hi Rohan!

Depending on the details, you might be able to fix your issue using pytorch's
sweep-inplace-modification-errors-under-the-rug context manager,
torch.autograd.graph.allow_mutation_on_saved_tensors().
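For reference, a minimal, self-contained illustration of that context manager (adapted from the pytorch docs, not from your model):

import torch

# Both the forward and backward passes must run inside the context
# manager; autograd then clones tensors saved for backward, so mutating
# the originals in place no longer raises the version-counter error.
with torch.autograd.graph.allow_mutation_on_saved_tensors():
    a = torch.ones(2, 3, requires_grad=True)
    b = a.clone()
    out = (b ** 2).sum()
    b.sin_()          # in-place op on a tensor saved for backward
    out.backward()    # still works under the context manager
print(a.grad)

As the name suggests, this hides the symptom rather than removing the in-place modification, so treat it as a workaround.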

You might try:

    model = DDP(model, broadcast_buffers=False, device_ids=[rank])

Does your error occur only when you use DDP (or only when you use DDP with
more than one GPU)?

Apparently DDP's broadcast_buffers can count as an in-place modification,
possibly leading to your error. (See, for example, this github issue.)

If you need to further debug your issue, start with this. The problem tensor
is a LongTensor, so maybe it’s being used for indexing. Look in your code
for things like argmax() or argsort() that return index values that you then
might modify and use to index other tensors in your forward pass.

Note that my_indices[3:5] = 0 or my_indices += 1 would both count
as in-place modifications.
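For illustration, here is a minimal sketch (not from your code) showing the internal version counter that the "is at version 3; expected version 2" message is reporting:

import torch

my_indices = torch.zeros(1, 54, dtype=torch.long)  # a LongTensor with the reported shape
print(my_indices._version)   # 0 (._version is pytorch's internal version counter)
my_indices += 1              # in-place addition bumps the version counter
my_indices[0, 3:5] = 0       # so does slice assignment
print(my_indices._version)   # 2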

In particular, can you locate any LongTensors that have the reported shape
of [1, 54]?

Good luck.

K. Frank


Thanks, the broadcast_buffers=False fix worked!