Distributed Data Parallel with Triplet Transformer Model

Hello, I’m working on a triplet model for semantic search, and I’m running into serious problems training it with the DistributedDataParallel module. Essentially, the model looks like this (the actual model is much more complicated, e.g. it contains a Transformer inside, but for the sake of simplicity here is a simplified scheme):

> class TripletModel(nn.Module):
>     def __init__(self, encoder):
>         super(TripletModel, self).__init__()
>         self.encoder = encoder
> 
>     def forward(self, x1, x2, x3):
>         anchor_emb = self.encoder(x1)
>         positive_pair_emb = self.encoder(x2)
>         negative_pair_emb = self.encoder(x3)
>         return anchor_emb, positive_pair_emb, negative_pair_emb
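
These three embeddings then go into the triplet loss described below; a minimal sketch of that part (simplified, using torch.nn.TripletMarginWithDistanceLoss with 1 - cosine similarity as the distance function, which is essentially what my full code does):

> import torch
> import torch.nn as nn
> import torch.nn.functional as F
> 
> # cosine *distance* = 1 - cosine similarity, used as the distance function
> cosine_distance = lambda x, y: 1.0 - F.cosine_similarity(x, y)
> criterion = nn.TripletMarginWithDistanceLoss(distance_function=cosine_distance, margin=0.2)
> 
> # dummy embeddings with the shape the encoder returns: (batch, hidden_dim)
> anchor_emb = torch.randn(8, 768)
> positive_emb = torch.randn(8, 768)
> negative_emb = torch.randn(8, 768)
> loss = criterion(anchor_emb, positive_emb, negative_emb)  # scalar, ready for backward()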

The representations for the triplets are created by the same encoder and compared with TripletMarginLoss, with cosine distance as the distance metric, as in the sketch above. I wanted to distribute the training across two GPUs (on the same node) using DistributedDataParallel, but during training I receive the following error:

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 
1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.
2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 99 with name backbone.encoder.encoder.layer.5.output.LayerNorm.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Has anyone worked on a similar problem? I’m training the model using the HuggingFace Trainer API, with gradient checkpointing enabled and the find_unused_parameters parameter of DistributedDataParallel set to False (I also have a test inside the training loop that checks for unused parameters, so there shouldn’t be any). Most of the discussions I’ve found online (e.g., does Gradient checkpointing support multi-gpu ? · Issue #63 · allenai/longformer · GitHub) point to this particular parameter (find_unused_parameters), but it is already set to False, and I’m confident there are no unused parameters.
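
For reference, the DDP setup the Trainer ends up with should be essentially equivalent to this (a simplified sketch, not the actual Trainer internals; wrap_in_ddp is just an illustrative helper):

> import torch.nn as nn
> 
> def wrap_in_ddp(model: nn.Module, local_rank: int) -> nn.Module:
>     """Simplified equivalent of the DDP configuration used in my training."""
>     ddp_model = nn.parallel.DistributedDataParallel(
>         model.to(local_rank),
>         device_ids=[local_rank],
>         output_device=local_rank,
>         find_unused_parameters=False,  # already set to False
>     )
>     # the error message suggests this private static-graph workaround; I have not tried it:
>     # ddp_model._set_static_graph()
>     return ddp_model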

Additionally, if I turn off gradient_checkpointing, I receive the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [4, 447]] is at version 3; expected version 2 instead. 
Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The backtrace points to the token_type_embeddings embedding layer in the Roberta model of Huggingface Transformers (the modeling_roberta.py module):

[...]lib/python3.6/site-packages/transformers/models/roberta/modeling_roberta.py", line 132, in forward
token_type_embeddings = self.token_type_embeddings(token_type_ids)
[...]
[...]lib/python3.6/site-packages/torch/nn/functional.py", line 2043, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
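
(For completeness: the detailed backtrace above comes from autograd anomaly detection, which I enable before training, e.g.:)

> import torch
> 
> # records forward-pass stack traces so the failing op can be reported
> # (at the cost of slower training); the "backtrace further above" comes from this
> torch.autograd.set_detect_anomaly(True)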

During training on a single GPU I receive no errors, and the training progresses normally.

Thanks for your post! Are you using weight sharing in your module? The linked thread (does Gradient checkpointing support multi-gpu ? · Issue #63 · allenai/longformer · GitHub) mentions that DDP does not work with gradient checkpointing + weight sharing in some cases, but we would need a more detailed reproduction to confirm the issue.

If you can get a repro of the issue, it would be great to file an issue at Issues · pytorch/pytorch · GitHub so we can look into it.
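
By weight sharing I mean the same parameter being reused in more than one place in the module, e.g. tied embeddings; an illustrative sketch (not your model):

import torch.nn as nn

class TiedLMHead(nn.Module):
    """Illustrative only: the output projection reuses (ties) the embedding weight,
    so a single parameter appears twice in the autograd graph."""
    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.decoder = nn.Linear(dim, vocab_size, bias=False)
        self.decoder.weight = self.embed.weight  # weight sharing / tying

    def forward(self, input_ids):
        return self.decoder(self.embed(input_ids))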

Thanks for the response! I’ve run an additional test in which I replaced the Transformer model with a single linear layer encoder, and everything ran without any problems, so it must be related to the Transformer itself. I’m not sure about the weight sharing - I’ve looked at the Roberta model implementation in huggingface (I’m training distilroberta as an encoder): transformers/modeling_roberta.py at master · huggingface/transformers · GitHub, and I don’t see any weight sharing there. Additionally, with gradient_checkpointing enabled, the backtrace points to the LayerNorm layer of the RobertaOutput class (5th layer of the model). I’ll try to provide a reproducible example.
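
For context, by "single linear layer encoder" I mean roughly this (simplified sketch; in that test it was fed float tensors rather than token ids and plugged into the same TripletModel wrapper from the first post):

import torch.nn as nn

class LinearEncoder(nn.Module):
    # trivial stand-in for the Transformer encoder used in the control run
    def __init__(self, in_dim=256, emb_dim=128):
        super().__init__()
        self.linear = nn.Linear(in_dim, emb_dim)

    def forward(self, x):
        return self.linear(x)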

Here is a reproducible example of the second error (the inplace operation), assuming you have two GPUs available (interestingly, it no longer throws the first error about marking a variable ready only once, even though gradient checkpointing is still turned on). Some parts of the code were simplified just to provide a working example (e.g., the random data creation and iteration over it - I know it could’ve been done more carefully, but I wanted to provide a working example as quickly as possible):

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from transformers import AutoConfig, AutoModel

def run_training(rank):
    # anomaly detection produces the detailed backtrace for the inplace-operation error
    torch.autograd.set_detect_anomaly(True)
    # rank comes from mp.spawn; WORLD_SIZE / MASTER_ADDR / MASTER_PORT are read from the environment
    torch.distributed.init_process_group(backend="nccl", rank=rank)

    device = torch.device('cuda', rank)
    
    # dummy dataset: 8 identical items, each holding a batch of 16 random triplets
    # (anchor / positive / negative), every sequence of length 256
    length_inp = 256
    data = [{'sent1': {'input_ids': torch.randint(0, 200, (16, length_inp)), 'attention_mask': torch.ones(16, length_inp)},
             'sent2': {'input_ids': torch.randint(0, 200, (16, length_inp)), 'attention_mask': torch.ones(16, length_inp)},
             'sent3': {'input_ids': torch.randint(0, 200, (16, length_inp)), 'attention_mask': torch.ones(16, length_inp)}
            }]*8
    
    train_sampler = DistributedSampler(
                    data,
                    num_replicas=torch.cuda.device_count(),
                    rank=rank,
                    seed=44,
                )
    
    dll = DataLoader(
            data,
            batch_size=1,
            sampler=train_sampler,
            drop_last=False,
            num_workers=2,
            pin_memory=True,
        )

    # distilroberta-base encoder with gradient checkpointing enabled, wrapped in DDP
    config = AutoConfig.from_pretrained('distilroberta-base')
    model = AutoModel.from_pretrained('distilroberta-base', config=config, add_pooling_layer=False)
    model.to(device)
    model.gradient_checkpointing_enable()

    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    optimizer = torch.optim.Adam(model.parameters())
    # cosine distance (1 - cosine similarity) as the distance function for the triplet loss
    metric = lambda x, y: 1.0 - F.cosine_similarity(x, y)
    criterion = nn.TripletMarginWithDistanceLoss(distance_function=metric, margin=0.2, reduction='none')
    
    for n, b in enumerate(dll):
        print(n)
        model.zero_grad()
        # the DataLoader adds a batch dimension of 1; squeeze it back to (16, length_inp)
        sent1 = {k: v.squeeze(0) for k, v in b['sent1'].items()}
        sent2 = {k: v.squeeze(0) for k, v in b['sent2'].items()}
        sent3 = {k: v.squeeze(0) for k, v in b['sent3'].items()}

        # three separate forward passes through the same DDP-wrapped encoder,
        # taking the embedding of the first ([CLS]) token from the last hidden state
        emb1 = model(**sent1)[0][:, 0, :]
        emb2 = model(**sent2)[0][:, 0, :]
        emb3 = model(**sent3)[0][:, 0, :]
        
        losses = criterion(emb1, emb2, emb3)
        loss = losses.mean()
        loss.backward()
        
        optimizer.step()
        print('Model device: {}, loss device: {}, loss: {}'.format(model.device, loss.device, loss))


def main():
    # one process per GPU; MASTER_ADDR / MASTER_PORT / WORLD_SIZE are read by init_process_group
    world_size = torch.cuda.device_count()
    os.environ["MASTER_PORT"] = '1234'
    os.environ["MASTER_ADDR"] = '127.0.0.1'
    os.environ["WORLD_SIZE"] = str(world_size)
    mp.spawn(run_training,
        nprocs=world_size,
        join=True)

if __name__ == "__main__":
    main()
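
To run it, I just launch the script directly with two GPUs visible (e.g. CUDA_VISIBLE_DEVICES=0,1 python <whatever_you_save_it_as>.py); mp.spawn then starts one process per GPU.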

I hope this is enough. I also wasn’t sure whether this problem is suitable as an issue on GitHub; if it is, please let me know :slight_smile: