Calculating gradients for BERT fine-tuning on multiple GPUs

Hi,

Problem description: I add a layer on top of BERT and fine-tune all of BERT's layers (none frozen) for a downstream classification task. The input data contains long documents, so I split each document into a list of 512-token tensors (chunks). I pass one document as one batch in the forward pass, with requires_grad=True on the BERT parameters for fine-tuning. Memory grows with every chunk, because each chunk's activations are kept alive for the backward pass. In practice it behaves as if each chunk created its own copy of BERT, so a document with 40 chunks takes roughly the memory of 40 BERT instances. Some documents are very long, and all of their chunk tensors don't fit on one GPU. So I manually split a single input across multiple GPUs (2 at most), where one GPU fits at most 40 chunks of one input.
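For reference, the chunking step amounts to something like the sketch below. It assumes the document is already tokenized into a flat list of token ids, and that every chunk gets its own [CLS] and [SEP] and is right-padded to 512. The helper name `chunk_token_ids` and the default ids (101, 102 and 0 for [CLS], [SEP] and [PAD] in the bert-base-uncased vocabulary) are illustrative assumptions, not my exact preprocessing.

```python
from typing import List, Tuple


def chunk_token_ids(
    token_ids: List[int],
    max_len: int = 512,
    cls_id: int = 101,  # [CLS] in the bert-base-uncased vocabulary
    sep_id: int = 102,  # [SEP]
    pad_id: int = 0,    # [PAD]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split one tokenized document into fixed-length BERT chunks.

    Each chunk is [CLS] + up to (max_len - 2) content tokens + [SEP],
    right-padded to max_len. Returns (input_ids, attention_masks).
    """
    body = max_len - 2  # room left after [CLS] and [SEP]
    ids_chunks, mask_chunks = [], []
    for start in range(0, len(token_ids), body):
        piece = [cls_id] + token_ids[start:start + body] + [sep_id]
        n_pad = max_len - len(piece)
        ids_chunks.append(piece + [pad_id] * n_pad)
        mask_chunks.append([1] * len(piece) + [0] * n_pad)  # 1 = real token
    return ids_chunks, mask_chunks


# A hypothetical 1,200-token document.
ids, masks = chunk_token_ids(list(range(1000, 2200)))
```

With a 1,200-token document this yields 3 chunks: two full ones and a last one with 182 real tokens (180 content tokens plus [CLS] and [SEP]). Each `(ids[i], masks[i])` pair corresponds to one `sent_id[i]`, `mask[i]` in the forward pass.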

The issue is the backward pass: since the computation is spread over several CUDA devices, I am not sure how the loss knows which device to backpropagate to.
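For what it's worth, a small self-contained check (not my real model) suggests that autograd records `.to(device)` copies. In this check a parameter lives on one device and the loss is computed on another, yet the gradient still arrives at the parameter. The device names are assumptions; on a machine with fewer than two GPUs the sketch falls back to a single device so it still runs.

```python
import torch
import torch.nn as nn

# Pick two devices if available; otherwise fall back so the sketch still runs.
n_gpu = torch.cuda.device_count()
dev_a = torch.device("cuda:0" if n_gpu >= 1 else "cpu")
dev_b = torch.device("cuda:1") if n_gpu >= 2 else dev_a

torch.manual_seed(0)
layer = nn.Linear(4, 1).to(dev_a)      # parameters live on dev_a
x = torch.randn(3, 4, device=dev_b)    # input lives on dev_b

# .to() is differentiable: both device copies are recorded in the graph.
out = layer(x.to(dev_a)).to(dev_b)
loss = out.sum()                       # the loss lives on dev_b
loss.backward()                        # gradients flow back to dev_a

print(layer.weight.grad.device, layer.weight.grad.shape)
```

Because `loss` is a plain sum, the weight gradient should equal the column sums of `x` and the bias gradient the number of rows, which makes the cross-device result easy to verify.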

Also, DataParallel sounds like a good way to implement this, but I am not quite sure, because it splits the input batch into mini-batches, which does not match my case. For example, a batch used with DataParallel might contain 120 documents with 120 labels. With 3 GPUs, each GPU gets 40 documents/labels, and the loss can be computed for each mini-batch in the forward pass. The losses are then summed/averaged on the main GPU (say GPU:0). More explanation: here. In my case, however, I split a single input (document) with a single label across 2 GPUs, so I am not sure how the individual GPU losses can be computed.
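To make the DataParallel comparison concrete: `nn.DataParallel` scatters whatever tensors it receives along dimension 0. The sketch below shows one possible way to reuse that for a single document. It assumes all chunks of the document are stacked into a `[num_chunks, 512]` tensor, so the chunks (not documents) play the role of the batch. `ToyEncoder` and `DocClassifier` are hypothetical stand-ins, so the snippet runs without downloading BERT. Without CUDA, DataParallel simply calls the wrapped module.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BERT: one [CLS]-like vector per chunk."""

    def __init__(self, vocab_size=1000, hidden=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask):
        # attention_mask mirrors BERT's signature but is unused in this toy
        h = self.emb(input_ids)        # [chunks, seq, hidden]
        return h[:, 0, :]              # [chunks, hidden], first token only


class DocClassifier(nn.Module):
    def __init__(self, encoder, hidden=32):
        super().__init__()
        # DataParallel splits dim 0 (the chunks of ONE document) across the
        # visible GPUs and gathers the outputs on device_ids[0].
        self.encoder = nn.DataParallel(encoder)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(hidden, 1)

    def forward(self, input_ids, attention_mask):
        cls = self.encoder(input_ids, attention_mask)  # [chunks, hidden]
        doc = cls.mean(dim=0, keepdim=True)            # [1, hidden]
        return self.fc1(self.dropout(doc))             # [1, 1]


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DocClassifier(ToyEncoder()).to(device)
ids = torch.randint(0, 1000, (5, 512), device=device)  # 5 chunks of one doc
mask = torch.ones_like(ids)

logit = model(ids, mask)
loss = nn.functional.binary_cross_entropy_with_logits(
    logit, torch.ones(1, 1, device=device))             # one label per doc
loss.backward()
```

The trade-off is that every GPU still holds a full replica plus the activations of its share of the chunks, but the loss is a single scalar on one device and there is no manual per-device bookkeeping.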

Any quick help/suggestions would be appreciated: @ptrblck @albanD

Below is the class snippet:

```python
import gc

import torch
import torch.nn as nn


class BERT_Arch(nn.Module):

    def __init__(self, bert):
        super(BERT_Arch, self).__init__()

        self.bert = bert

        # dropout layer
        self.dropout = nn.Dropout(0.1)

        # dense layer 1: a single logit per document
        self.fc1 = nn.Linear(768, 1)

    # define the forward pass
    def forward(self, sent_id, mask, hist):
        torch.cuda.empty_cache()
        chunk_num = len(sent_id)
        print("# chunks: ", chunk_num)

        flag1 = False

        cls_vec1 = []  # [CLS] vectors of chunks 0-39 (cuda:1)
        cls_vec2 = []  # [CLS] vectors of chunks 40-79 (cuda:2)
        cls_mean_all = []

        for i in range(len(sent_id)):
            print("chunk id: ", i)

            if i < 40:
                device = "cuda:1"

                ip_id = torch.as_tensor(sent_id[i], device=device).unsqueeze(0)
                attn_mask = torch.as_tensor(mask[i], device=device).unsqueeze(0)

                # pass the inputs to the model
                # NOTE: self.bert's weights live on one device only, so the
                # inputs must be on that same device for this call to work
                outputs = self.bert(input_ids=ip_id, attention_mask=attn_mask)

                cls_hs = outputs[0][:, 0, :]  # [CLS] hidden state, [1, 768]
                cls_vec1.append(cls_hs)

                del cls_hs
                gc.collect()
                torch.cuda.empty_cache()

            elif i < 80:  # was `i > 40 and i < 80`, which skipped chunk 40
                flag1 = True

                device = "cuda:2"

                ip_id = torch.as_tensor(sent_id[i], device=device).unsqueeze(0)
                attn_mask = torch.as_tensor(mask[i], device=device).unsqueeze(0)

                # pass the inputs to the model
                outputs = self.bert(input_ids=ip_id, attention_mask=attn_mask)

                cls_hs = outputs[0][:, 0, :]
                cls_vec2.append(cls_hs)

                del cls_hs
                gc.collect()
                torch.cuda.empty_cache()
            # chunks from index 80 on are currently ignored (2 GPUs x 40 max)

        cls_mean1 = torch.mean(torch.stack(cls_vec1, dim=0), dim=0)
        x = cls_mean1

        if flag1:
            cls_mean2 = torch.mean(torch.stack(cls_vec2, dim=0), dim=0)

            # torch.stack needs both means on the same device
            cls_mean_all.append(cls_mean1)
            cls_mean_all.append(cls_mean2.to(cls_mean1.device))

            cls_mean = torch.mean(torch.stack(cls_mean_all, dim=0), dim=0)

            x = cls_mean

        # fc1 lives on the model's device, so move x there first
        x = x.to(self.fc1.weight.device)
        x = self.dropout(x)
        y = self.fc1(x)

        return y
```

```python
from transformers import AutoModel

device = "cuda:0"
LM = 'bert-base-uncased'
base_model = AutoModel.from_pretrained(LM)
for param in base_model.parameters():
    param.requires_grad = True  # unfreeze all BERT layers (already the default)
model = BERT_Arch(base_model)
model.to(device)
```
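One way to reduce the per-chunk memory of the loop in `forward` is to swap in a technique that is not in my current code: activation checkpointing with `torch.utils.checkpoint.checkpoint`. It drops each chunk's intermediate activations after the forward pass and recomputes them during backward. Only each chunk's inputs and [CLS] output stay in memory, at the cost of roughly one extra forward pass per chunk. `ToyEncoder` and `encode_chunks` are hypothetical. With the real model the encoder would be a small function returning `self.bert(...)[0][:, 0, :]` (untested assumption). Hugging Face models also provide `gradient_checkpointing_enable()`, which checkpoints inside BERT's layers instead.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BERT (returns a [CLS]-like vector)."""

    def __init__(self, vocab_size=1000, hidden=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.ff = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(),
                                nn.Linear(4 * hidden, hidden))

    def forward(self, input_ids, attention_mask):
        # attention_mask mirrors BERT's signature but is unused in this toy
        h = self.ff(self.emb(input_ids))  # [1, seq, hidden]
        return h[:, 0, :]                 # [1, hidden]


def encode_chunks(encoder, chunks, masks):
    """Encode chunks one at a time, recomputing activations in backward."""
    cls_vecs = []
    for ids, mask in zip(chunks, masks):
        # use_reentrant=False lets parameter gradients flow even though the
        # integer input_ids themselves do not require grad.
        cls_vecs.append(checkpoint(encoder, ids.unsqueeze(0),
                                   mask.unsqueeze(0), use_reentrant=False))
    return torch.cat(cls_vecs, dim=0).mean(dim=0, keepdim=True)  # [1, hidden]


torch.manual_seed(0)
enc = ToyEncoder()
head = nn.Linear(32, 1)
chunks = torch.randint(0, 1000, (6, 512))  # 6 chunks of one document
masks = torch.ones_like(chunks)

logit = head(encode_chunks(enc, chunks, masks))
logit.sum().backward()
```

The gradients are the same as without checkpointing; only the memory/compute trade-off changes.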

Thanks!