Idiom classification based on cosine similarity

Hi! I am trying to build a model that takes an input sentence containing a multi-word expression (MWE), produces a vector representation for both the MWE and its context, and then predicts from their cosine similarity whether the sentence is idiomatic or not. To encode the expression I take the sum of the individual words' representations, and for the context I take the representation of the [CLS] token. However, I am running into a matrix shape multiplication error. What am I doing wrong here? Am I not calculating the cosine similarity correctly? In my head everything is clear, but I am not very experienced with torch and transformers :frowning: Any help would be greatly appreciated!

```python
import torch
import torch.nn as nn
from transformers import DistilBertModel

class DUAL_ENCODER(nn.Module):
    def __init__(self, labels):
        super(DUAL_ENCODER, self).__init__()
        self.bert_model = DistilBertModel.from_pretrained("distilbert-base-multilingual-cased")
        self.classifier = nn.Linear(self.bert_model.config.hidden_size, labels)

    def forward(self, mwe_input_ids, mwe_attention_mask, context_input_ids, context_attention_mask):
        mwe = self.bert_model(mwe_input_ids, attention_mask=mwe_attention_mask)
        context = self.bert_model(context_input_ids, attention_mask=context_attention_mask)
        # extract hidden states:
        # sum the words of the expression, excluding the [CLS] and [SEP] tokens
        mwe_hidden_state = mwe.last_hidden_state[:, 1:-1, :].sum(dim=1)
        # take the [CLS] token as the whole-sentence representation
        context_hidden_state = context.last_hidden_state[:, 0, :]
        print('mwe', mwe_hidden_state.size(), 'context', context_hidden_state.size())
        similarity = torch.cosine_similarity(mwe_hidden_state, context_hidden_state, dim=1)
        print('similarity', similarity.unsqueeze(1).size())
        predictions = self.classifier(similarity.unsqueeze(1))

        return predictions
```
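To sanity-check the pooling and cosine steps on their own, I reproduced them with random tensors standing in for the BERT outputs (the sizes below are just my batch of 4 and DistilBERT's 768-dim hidden states):

```python
import torch

batch_size, seq_len, hidden_size = 4, 10, 768

# stand-ins for mwe.last_hidden_state and context.last_hidden_state
mwe_last_hidden = torch.randn(batch_size, seq_len, hidden_size)
context_last_hidden = torch.randn(batch_size, seq_len, hidden_size)

# sum the expression tokens, excluding [CLS] and [SEP]
mwe_vec = mwe_last_hidden[:, 1:-1, :].sum(dim=1)   # shape (4, 768)
# [CLS] token as the whole-sentence representation
context_vec = context_last_hidden[:, 0, :]         # shape (4, 768)

similarity = torch.cosine_similarity(mwe_vec, context_vec, dim=1)
print(similarity.shape)               # torch.Size([4])
print(similarity.unsqueeze(1).shape)  # torch.Size([4, 1])
```

So the classifier ends up receiving a (4, 1) tensor — one similarity score per example — which matches the sizes my print statements report.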

```
RuntimeError                              Traceback (most recent call last)
<ipython-input-37-d5608ed79afa> in <cell line: 3>()
     14         labels = batch['labels']
     15 
---> 16         outputs = model(mwe_ids, mwe_attn_mask, context_ids, context_attn_mask)
     17         loss = loss_fn(outputs.squeeze(), labels)
     18         loss.backward()

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1 and 768x4631)
```
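For what it's worth, I can reproduce the exact same error without BERT at all, just by feeding a (4, 1) tensor to a linear layer that expects 768 input features (4631 is the `labels` value from the error message above):

```python
import torch
import torch.nn as nn

num_labels = 4631                        # taken from the error message
classifier = nn.Linear(768, num_labels)  # expects 768 features per example
similarity = torch.randn(4, 1)           # but gets one cosine score per example

try:
    classifier(similarity)
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (4x1 and 768x4631)
```

So the mismatch seems to be between the classifier's `in_features` (`hidden_size` = 768) and the single similarity score per example that I actually pass in — but I'm not sure what the right way to set this up is.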