Dual BERT / Two-Channel BERT: CUDA out of memory

Hello,

I am using PyTorch for a BERT model. The input data comes in two types: customer reviews and agent replies. Because the customer review is more important and on its own already exceeds BERT's 512-token limit, I don't want to concatenate the two texts into a single input.
I can't find the exact term for the technique I am trying, but basically I want to try three versions:
Version 1: Take the customer review as input to the base BERT model, add a binary classifier on top (the label is either yes or no, so there is a single output), and use the [CLS] token to make the prediction.
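A minimal sketch of the Version 1 head, assuming the encoder's last hidden state has shape (batch_size, seq_len, hidden_size) with [CLS] at position 0. A random tensor stands in for the BERT output and the sizes are toy values, so only the head is shown: one linear layer maps the [CLS] vector to a single logit, and `nn.BCEWithLogitsLoss` handles the yes/no label. (The model code below uses `review_outputs[1]`, BERT's pooled [CLS] output, which adds a dense + tanh layer on top of the raw [CLS] vector used here.)

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, seq_len, hidden_size = 4, 16, 8  # toy sizes, not BERT's real dimensions

# Stand-in for BertModel's last hidden state: (batch_size, seq_len, hidden_size)
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)


class CLSBinaryHead(nn.Module):
    """Version 1 head: [CLS] vector -> dropout -> one logit."""

    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state):
        cls_vector = last_hidden_state[:, 0]  # [CLS] is the first token: (batch_size, hidden_size)
        return self.classifier(self.dropout(cls_vector)).squeeze(-1)  # (batch_size,)


head = CLSBinaryHead(hidden_size)
logits = head(last_hidden_state)
labels = torch.tensor([1.0, 0.0, 0.0, 1.0])  # yes/no labels as floats
loss = nn.BCEWithLogitsLoss()(logits, labels)
print(logits.shape, loss.item())
```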

Version 2, Concat Agent: Take both the customer review and the agent reply as input, sharing the BERT parameters between them. Then concatenate the two [CLS] embedding vectors, one from the review and one from the agent reply, and add the binary classifier on top of the concatenated vector.
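A minimal sketch of the Version 2 idea, assuming a hypothetical `ToyEncoder` in place of `BertModel` (an embedding plus a linear layer, toy sizes) so it runs without downloading weights. The same encoder instance is called on both texts, so the parameters are shared, and the two [CLS] vectors are concatenated before a `Linear(2 * hidden_size, 1)` classifier.

```python
import torch
from torch import nn

torch.manual_seed(0)

vocab_size, hidden_size = 100, 8  # toy sizes, not BERT's real dimensions


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BertModel: token ids -> (batch_size, seq_len, hidden_size)."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids):
        return torch.tanh(self.proj(self.embed(input_ids)))


class ConcatAgentHead(nn.Module):
    """Version 2: one shared encoder, concatenated [CLS] vectors, one logit."""

    def __init__(self, encoder, hidden_size):
        super().__init__()
        self.encoder = encoder  # shared by review and agent text
        self.classifier = nn.Linear(2 * hidden_size, 1)

    def forward(self, review_ids, agent_ids):
        review_cls = self.encoder(review_ids)[:, 0]  # (batch_size, hidden_size)
        agent_cls = self.encoder(agent_ids)[:, 0]    # same weights, second forward pass
        feature = torch.cat([review_cls, agent_cls], dim=1)  # (batch_size, 2 * hidden_size)
        return self.classifier(feature).squeeze(-1)  # (batch_size,)


model = ConcatAgentHead(ToyEncoder(vocab_size, hidden_size), hidden_size)
review_ids = torch.randint(0, vocab_size, (3, 20))  # the two texts may have different lengths
agent_ids = torch.randint(0, vocab_size, (3, 7))
logits = model(review_ids, agent_ids)
print(logits.shape)  # torch.Size([3])
```

Because the shared encoder runs twice per step, autograd keeps the activations of both passes until `backward()`. Sharing parameters saves parameter memory but not activation memory, so when the agent texts are about as long as the reviews, the encoder's activation memory roughly doubles compared with Version 1 at the same batch size.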

Version 3, Attention Agent: Take both the customer review and the agent reply as input, again sharing the BERT parameters. Mean-pool the hidden states of the agent reply to get its sentence embedding and use that as the query of an attention layer; the hidden states of the review serve as the keys and values. The binary classifier goes on top of the attention output.
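A minimal sketch of the Version 3 head, assuming PyTorch 1.9 or newer (for `batch_first=True` in `nn.MultiheadAttention`) and random tensors in place of the BERT hidden states. Unlike the model code below, which takes a plain mean and passes no mask, this sketch averages the agent states over real tokens only and passes `key_padding_mask` so the query cannot attend to review padding; that masking is a suggested variant, not part of the original design.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, review_len, agent_len, hidden_size, num_heads = 2, 10, 6, 8, 2  # toy sizes

# Stand-ins for BERT last hidden states: (batch_size, seq_len, hidden_size)
review_hidden = torch.randn(batch_size, review_len, hidden_size)
agent_hidden = torch.randn(batch_size, agent_len, hidden_size)

# Attention masks in the BERT convention: 1 = real token, 0 = padding
review_mask = torch.ones(batch_size, review_len, dtype=torch.long)
review_mask[1, 7:] = 0
agent_mask = torch.ones(batch_size, agent_len, dtype=torch.long)
agent_mask[0, 4:] = 0


def masked_mean(hidden, mask):
    """Average over real tokens only: (batch, seq, hidden), (batch, seq) -> (batch, hidden)."""
    mask = mask.unsqueeze(-1).type_as(hidden)  # (batch, seq, 1)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)    # avoid division by zero for empty rows
    return summed / counts


attention = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True)
classifier = nn.Linear(hidden_size, 1)

query = masked_mean(agent_hidden, agent_mask).unsqueeze(1)  # (batch, 1, hidden): one query per example
attn_output, attn_weights = attention(
    query, review_hidden, review_hidden,
    key_padding_mask=(review_mask == 0),  # True marks review padding positions to ignore
)
feature = attn_output.squeeze(1)          # (batch, hidden)
logits = classifier(feature).squeeze(-1)  # (batch,)
```

With `batch_first=True` the transposes in the model code are not needed; with the default `batch_first=False`, the query would be `(1, batch, hidden)` and the keys `(seq_len, batch, hidden)`, as in that code.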

My base model (Version 1) currently works without problems. However, whenever I run either of the other two versions, I get a CUDA out of memory error. I am not sure whether my code has an issue or the SageMaker instance is simply too limited for this task.

Can anyone help me with this issue?

Thank you for taking the time to read through this. Any pointers help! It's just hard to find anyone else doing the same thing.

Below is my code for the model part:

import torch
from torch import nn
from transformers import BertModel, BertPreTrainedModel

class ReviewClassification(BertPreTrainedModel):
    def __init__(self, config,
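                 add_agent_text, agent_text_heads):
        # add_agent_text: None (version 1), "concat" (version 2) or "attention" (version 3)
        super().__init__(config)
        # a single logit is used for the yes/no label, so self.num_labels is not set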
        self.add_agent_text = add_agent_text

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        embedding_size = config.hidden_size

        if self.add_agent_text == "concat":
            embedding_size = 2 * embedding_size
        elif self.add_agent_text == "attention":
            self.agent_attention = nn.MultiheadAttention(embedding_size, num_heads=agent_text_heads)
        else:
            # don't use the information in Agent text
            pass


        self.classifier = nn.Linear(embedding_size, 1)  # single logit for BCEWithLogitsLoss
        self.init_weights()

    def forward(
            self,
            review_input_ids=None,
            review_attention_mask=None,
            review_token_type_ids=None,
            agent_input_ids=None,
            agent_attention_mask=None,
            agent_token_type_ids=None,
            labels=None,
    ):
        # encode the review text with the shared BERT encoder
        review_outputs = self.bert(
            review_input_ids,
            attention_mask=review_attention_mask,
            token_type_ids=review_token_type_ids,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
        )
        if self.add_agent_text is not None:
            # means that self.add_agent_text is "concat" or "attention"
            # TODO: try a separate (non-shared) BERT encoder for the agent text
            agent_outputs = self.bert(
                agent_input_ids,
                attention_mask=agent_attention_mask,
                token_type_ids=agent_token_type_ids,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
            )
        if self.add_agent_text == "attention":
            # Use the review hidden states as key and value. nn.MultiheadAttention defaults to
            # batch_first=False, so transpose to (seq_len, batch_size, hidden_size); see
            # https://pytorch.org/docs/master/generated/torch.nn.MultiheadAttention.html
            review_hidden_states = review_outputs[0].transpose(0, 1)  # before transpose: (batch_size, seq_len, hidden_size)

            # Use the mean of the agent hidden states as the query, shape (target_seq_len, batch_size, hidden_size).
            # Note: this plain mean also averages over padding positions.
            agent_hidden_states = agent_outputs[0].mean(dim=1).unsqueeze(dim=0)  # (1, batch_size, hidden_size)

            attn_output, _ = self.agent_attention(agent_hidden_states, review_hidden_states, review_hidden_states)
            feature = attn_output.squeeze(0)  # (batch_size, hidden_size); squeeze(0) keeps the batch dim if batch_size == 1
        else:
            # don't use the attention mechanism
            # two options for the classification feature:
            # 1. use the pooled output of the first [CLS] token
            feature = review_outputs[1]  # pooler output: (batch_size, hidden_size)
            # 2. use the mean of the last hidden states
            # feature = review_outputs[0].mean(dim=1)

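        if self.add_agent_text == "concat":
            # review [CLS] + agent [CLS]: (batch_size, 2 * hidden_size)
            feature = torch.cat([feature, agent_outputs[1]], dim=1)

        feature = self.dropout(feature)  # apply the dropout defined in __init__ before the classifier
        logits = self.classifier(feature).squeeze(-1)  # (batch_size,); squeeze(-1) keeps the batch dim if batch_size == 1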

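        outputs = (logits,)  # hidden states and attentions are not returned

        if labels is not None:
            # weight on the positive class; created on the same device as the logits
            pos_weight = torch.tensor(8.85, device=logits.device)
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            # BCEWithLogitsLoss expects float targets with the same shape as the logits
            loss = loss_fct(logits, labels.float())
            outputs = (loss,) + outputs

        return outputs  # (loss, logits) if labels are given, else (logits,)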
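Separately from the memory question, `.squeeze()` with no dimension argument removes every size-1 dimension, so when a batch contains a single example (for example, the last batch of an epoch) the batch dimension disappears as well. A quick check using only PyTorch:

```python
import torch

# Classifier output for a batch with a single example: (batch_size=1, 1)
logits_2d = torch.randn(1, 1)

print(logits_2d.squeeze().shape)    # torch.Size([]): the batch dimension is gone too
print(logits_2d.squeeze(-1).shape)  # torch.Size([1]): only the logit dimension is removed

# Same for the attention output, shaped (target_seq_len=1, batch_size=1, hidden_size)
attn_output = torch.randn(1, 1, 8)
print(attn_output.squeeze().shape)   # torch.Size([8])
print(attn_output.squeeze(0).shape)  # torch.Size([1, 8])
```

With 0-d logits and labels of shape `(1,)`, `BCEWithLogitsLoss` raises an error because the input and target sizes differ, which is why the explicit dimension is the safer choice.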