Hello there,
I am currently trying to build a multimodal emotion recognition model using BERT and the Audio Spectrogram Transformer (AST), but I ran into an issue when training it.
The error is as follows:
```
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
<ipython-input-23-1eec2900cfb9> in forward(self, text_input, audio_input)
8 def forward(self, text_input, audio_input):
9 text_output = self.text_model(**text_input).hidden_states[-1][:, 0, :]
---> 10 audio_output = self.audio_model(audio_input).last_hidden_state
11 concatenated = torch.cat((text_output, audio_output), dim=-1)
12 logits = self.classifier(concatenated)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py in forward(self, input_values, head_mask, labels, output_attentions, output_hidden_states, return_dict)
571 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
572
--> 573 outputs = self.audio_spectrogram_transformer(
574 input_values,
575 head_mask=head_mask,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py in forward(self, input_values, head_mask, output_attentions, output_hidden_states, return_dict)
488 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
489
--> 490 embedding_output = self.embeddings(input_values)
491
492 encoder_outputs = self.encoder(
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py in forward(self, input_values)
85 distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
86 embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
---> 87 embeddings = embeddings + self.position_embeddings
88 embeddings = self.dropout(embeddings)
89
RuntimeError: The size of tensor a (146) must match the size of tensor b (1214) at non-singleton dimension 1
```
The error says there is a tensor size mismatch when running:
```python
logits = multimodal_model(text_input, audio_input)
```
My `multimodal_model` is created as follows:
```python
multimodal_model = MultimodalModel(text_model, ast_model, num_classes)
multimodal_model.to(device)
```
And `MultimodalModel` is defined as follows:
```python
import torch
import torch.nn as nn

class MultimodalModel(nn.Module):
    def __init__(self, text_model, audio_model, num_classes):
        super(MultimodalModel, self).__init__()
        self.text_model = text_model
        self.audio_model = audio_model
        self.classifier = nn.Linear(text_model.config.hidden_size + audio_model.config.hidden_size, num_classes)

    def forward(self, text_input, audio_input):
        text_output = self.text_model(**text_input).hidden_states[-1][:, 0, :]
        audio_output = self.audio_model(audio_input).last_hidden_state
        concatenated = torch.cat((text_output, audio_output), dim=-1)
        logits = self.classifier(concatenated)
        return logits
```
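Even once the sequence lengths agree, the next line may still fail: `last_hidden_state` is 3-D `(batch, seq_len, hidden)` while `text_output` is 2-D `(batch, hidden)`, so `torch.cat(..., dim=-1)` cannot join them. Also, judging from the traceback (the outer AST `forward` takes a `labels` argument), `ast_model` looks like `ASTForAudioClassification`, whose output, as far as I know, exposes `logits` and `hidden_states` rather than `last_hidden_state`. A minimal sketch, assuming `ast_model` is swapped for a bare `ASTModel`, whose `pooler_output` averages the [CLS] and distillation tokens into one `(batch, hidden)` vector; the text branch is kept exactly as in the post:

```python
import torch
import torch.nn as nn


class MultimodalModel(nn.Module):
    """Sketch: assumes `audio_model` is a bare ASTModel, not ASTForAudioClassification."""

    def __init__(self, text_model, audio_model, num_classes):
        super().__init__()
        self.text_model = text_model
        self.audio_model = audio_model
        self.classifier = nn.Linear(
            text_model.config.hidden_size + audio_model.config.hidden_size, num_classes
        )

    def forward(self, text_input, audio_input):
        # (batch, text_hidden): first token of the last text hidden layer, as in the original
        text_output = self.text_model(**text_input).hidden_states[-1][:, 0, :]
        # (batch, audio_hidden): pooled vector instead of the 3-D last_hidden_state
        audio_output = self.audio_model(audio_input).pooler_output
        concatenated = torch.cat((text_output, audio_output), dim=-1)
        return self.classifier(concatenated)
```

Using `last_hidden_state[:, 0, :]` on the audio side would be an equally valid 2-D choice; the point is that both halves must be `(batch, hidden)` before concatenation.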
The traceback shows that the error is raised at
```python
audio_output = self.audio_model(audio_input).last_hidden_state
```
which makes me believe the audio model is rejecting my input, since the mismatch occurs inside `/usr/local/lib/python3.10/dist-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py`.
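Both numbers in the error can be reproduced by hand. AST cuts the spectrogram into overlapping 16 × 16 patches with a strided convolution, then prepends two tokens ([CLS] and distillation, visible in the `torch.cat` line of the traceback) before adding `position_embeddings`. A minimal sketch, assuming the checkpoint uses what I believe are the `ASTConfig` defaults (`patch_size=16`, `frequency_stride=10`, `time_stride=10`, `num_mel_bins=128`, `max_length=1024`); `ast_sequence_length` is a hypothetical helper, not a transformers function:

```python
def ast_sequence_length(time_frames: int, num_mel_bins: int,
                        patch_size: int = 16,
                        frequency_stride: int = 10,
                        time_stride: int = 10) -> int:
    """Tokens that receive position embeddings: patches + [CLS] + distillation token."""
    # Strided Conv2d output size along each axis: (size - kernel) // stride + 1
    freq_out = (num_mel_bins - patch_size) // frequency_stride + 1
    time_out = (time_frames - patch_size) // time_stride + 1
    return freq_out * time_out + 2


# Length the pretrained position embeddings were built for (1024 frames x 128 mel bins)
print(ast_sequence_length(1024, 128))  # 1214
# Length produced by a 128 x 128 spectrogram
print(ast_sequence_length(128, 128))   # 146
```

If those assumptions hold, the pretrained position embeddings expect 1024 time frames, while a 128 × 128 spectrogram supplies only 128.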
For reference, these are the shapes of my audio data:
```
Train: torch.Size([9887, 128, 128])
Test: torch.Size([1094, 128, 128])
```
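One way to make these shapes fit a checkpoint that expects 1024 frames is to pad the time axis. A minimal sketch, assuming the tensors are laid out as `(batch, time_frames, mel_bins)` (the layout AST reads `input_values` in); if they are `(batch, mel_bins, time_frames)`, transpose first. `pad_or_truncate_time` is a hypothetical name, and zero is only "silence" if the spectrograms were not normalized beforehand:

```python
import torch
import torch.nn.functional as F


def pad_or_truncate_time(spec: torch.Tensor, target_frames: int = 1024) -> torch.Tensor:
    """Zero-pad or cut a batch of spectrograms along the time axis (dim 1).

    Assumes `spec` is (batch, time_frames, mel_bins). For (batch, mel_bins, time_frames)
    data, call spec.transpose(1, 2) first.
    """
    frames = spec.shape[1]
    if frames >= target_frames:
        return spec[:, :target_frames, :]
    # F.pad pairs run from the last dim backwards: (mel_left, mel_right, time_left, time_right)
    return F.pad(spec, (0, 0, 0, target_frames - frames))


audio = torch.randn(4, 128, 128)          # stand-in for one batch of the training tensor
print(pad_or_truncate_time(audio).shape)  # torch.Size([4, 1024, 128])
```

Padding the full 9887-clip tensor at once would take roughly 5 GB in float32, so applying this per batch is likely more practical. Alternatives worth checking: running raw 16 kHz waveforms through `ASTFeatureExtractor`, which should already return 1024-frame inputs, or loading the model with `max_length=128` and `ignore_mismatched_sizes=True`, in which case, as I understand it, the position embeddings are re-initialized rather than pretrained.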
And my text data is encoded with this code:
```python
train_text_encoded = text_tokenizer.batch_encode_plus(train_data, truncation=True, padding=True, max_length=100, return_tensors="pt")
test_text_encoded = text_tokenizer.batch_encode_plus(test_data, truncation=True, padding=True, max_length=100, return_tensors="pt")
```
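It is also worth checking that row *i* of the tokenizer output and row *i* of the audio tensor describe the same utterance, and that each batch reaches the model as a dict for `text_input` plus a tensor for `audio_input`. A minimal sketch, assuming a per-sample `labels` tensor (not shown in the post) and a hypothetical `TextAudioDataset` class; PyTorch's default collate stacks the dict entries key by key:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TextAudioDataset(Dataset):
    """Pairs row i of the tokenizer output with row i of the audio tensor (hypothetical helper)."""

    def __init__(self, text_encoded, audio, labels):
        # text_encoded: dict-like of (N, seq_len) tensors, e.g. input_ids / attention_mask
        assert all(len(v) == len(audio) for v in text_encoded.values()), "text/audio row counts differ"
        assert len(labels) == len(audio), "label/audio row counts differ"
        self.text_encoded = text_encoded
        self.audio = audio    # (N, time_frames, mel_bins)
        self.labels = labels  # (N,)

    def __len__(self):
        return len(self.audio)

    def __getitem__(self, i):
        text_input = {k: v[i] for k, v in self.text_encoded.items()}
        return text_input, self.audio[i], self.labels[i]


# Usage with fake tensors standing in for the real data
fake_text = {"input_ids": torch.randint(0, 100, (10, 12)),
             "attention_mask": torch.ones(10, 12, dtype=torch.long)}
loader = DataLoader(TextAudioDataset(fake_text, torch.randn(10, 128, 128),
                                     torch.randint(0, 4, (10,))), batch_size=4)
text_input, audio_input, labels = next(iter(loader))
print(text_input["input_ids"].shape, audio_input.shape)  # torch.Size([4, 12]) torch.Size([4, 128, 128])
```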
From these snippets, what fault in my data or code could have caused this error?
If more code is needed to debug this, I would be glad to share it.
Thank you