Hi! I am trying to create a model that takes an input sentence containing a multi-word expression (MWE), produces a vector representation for both the MWE and its context, and then predicts from their cosine similarity whether the sentence is idiomatic or not. To encode the expression I sum the representations of its individual words, and for the context I take the representation of the [CLS] token. However, I am running into a matrix shape mismatch error in a multiplication. What am I doing wrong here? Am I not calculating the cosine similarity correctly? In my head everything is clear, but I am not that experienced with torch and transformers. Any help would be greatly appreciated!
```
import torch
from torch import nn
from transformers import DistilBertModel


class DUAL_ENCODER(nn.Module):
    def __init__(self, labels):
        super(DUAL_ENCODER, self).__init__()
        self.bert_model = DistilBertModel.from_pretrained("bert-base-multilingual-cased")
        self.classifier = nn.Linear(self.bert_model.config.hidden_size, labels)

    def forward(self, mwe_input_ids, mwe_attention_mask, context_input_ids, context_attention_mask):
        context = self.bert_model(mwe_input_ids, mwe_attention_mask)
        mwe = self.bert_model(context_input_ids, context_attention_mask)
        # extract hidden states
        # sum words of expression excluding the [CLS] and [SEP] tokens
        mwe_hidden_state = mwe.last_hidden_state[:, 1:-1, :].sum(dim=1)
        # take the [CLS] token as the whole sentence representation
        context_hidden_state = context.last_hidden_state[:, 0, :]
        print('mwe', mwe_hidden_state.size(), 'context', context_hidden_state.size())
        similarity = torch.cosine_similarity(mwe_hidden_state, context_hidden_state, dim=1)
        print('similarity', similarity.unsqueeze(1).size())
        predictions = self.classifier(similarity.unsqueeze(1))
        return predictions
```
```
RuntimeError                              Traceback (most recent call last)
<ipython-input-37-d5608ed79afa> in <cell line: 3>()
     14     labels = batch['labels']
     15
---> 16     outputs = model(mwe_ids, mwe_attn_mask, context_ids, context_attn_mask)
     17     loss = loss_fn(outputs.squeeze(), labels)
     18     loss.backward()

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1 and 768x4631)
```
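
One possible reading of the error, offered as a sketch rather than a definitive fix: `torch.cosine_similarity(..., dim=1)` collapses the two 768-dimensional vectors to a single number per example, so after `unsqueeze(1)` the classifier receives a `(4, 1)` tensor, while `nn.Linear(hidden_size, labels)` was built to accept 768 features (and, judging by `768x4631`, `labels` was 4631 when the model was constructed). The cosine computation itself looks fine; the layer size simply does not match what it is fed. The snippet below uses random tensors as stand-ins for the DistilBERT outputs so the shapes can be checked without loading a model. The `nn.Linear(1, 1)` head and the `BCEWithLogitsLoss` loss are assumptions for a binary idiomatic/literal setup, not something taken from the post.

```python
import torch
from torch import nn

# Stand-in tensors: random values replace the real DistilBERT outputs
# (an assumption for illustration; no model is loaded here).
batch_size, seq_len, hidden_size = 4, 10, 768
mwe_last_hidden = torch.randn(batch_size, seq_len, hidden_size)
context_last_hidden = torch.randn(batch_size, seq_len, hidden_size)

# Same pooling as in the post: sum the inner tokens of the expression,
# take position 0 ([CLS]) for the context.
mwe_vec = mwe_last_hidden[:, 1:-1, :].sum(dim=1)   # shape (4, 768)
context_vec = context_last_hidden[:, 0, :]         # shape (4, 768)

# cosine_similarity reduces over dim=1, leaving one scalar per example.
similarity = torch.cosine_similarity(mwe_vec, context_vec, dim=1)  # shape (4,)
features = similarity.unsqueeze(1)                 # shape (4, 1)

# Hypothetical replacement head: sized to the single feature it receives,
# producing one logit per example.
classifier = nn.Linear(1, 1)
logits = classifier(features).squeeze(1)           # shape (4,)

# Assumed binary targets; BCEWithLogitsLoss expects float labels.
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
loss = nn.BCEWithLogitsLoss()(logits, labels)
print(features.shape, logits.shape, loss.item())
```

A few other things in the posted code may be worth checking. In `forward`, `context` is computed from `mwe_input_ids` and `mwe` from `context_input_ids`, which looks swapped. With padded batches, the slice `[:, 1:-1, :]` also sums padding positions (and may keep `[SEP]`) unless the hidden states are multiplied by the attention mask first. Finally, `bert-base-multilingual-cased` is a BERT checkpoint, so loading it through `DistilBertModel` may not initialise the weights as intended; `distilbert-base-multilingual-cased` or `BertModel` would be the matching pairings.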