Hi,
I have a hidden_states tuple of tensors from the output of a Hugging Face model, where each tensor has shape [batch_size, seq_len, hid_size].
I also have an attention_mask tensor from the batch samples (produced by the tokenizer) with shape [batch_size, seq_len].
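To make the shapes concrete, here is a minimal sketch of the inputs using random stand-in tensors. The sizes and num_layers are hypothetical and not taken from my real model; the attention_mask uses 1 for real tokens and 0 for padding, as tokenizers usually produce.

```python
import torch

batch_size, seq_len, hid_size = 2, 4, 3
num_layers = 2  # hypothetical; a real model has its own number of layers

# Stand-in for the model's hidden_states: a tuple of [batch_size, seq_len, hid_size] tensors
hidden_states = tuple(torch.randn(batch_size, seq_len, hid_size) for _ in range(num_layers))

# Stand-in for the tokenizer's attention_mask: 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

print(len(hidden_states), hidden_states[0].shape, attention_mask.shape)
```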
I want to apply the mask to hidden_states so that I extract only the elements where the mask is True. The extraction should act only on the first two dimensions and leave the hid_size dimension intact. I am looking for something like tf.boolean_mask().
I plan to flatten the first two dimensions together afterwards.
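A minimal sketch of the result I am after, assuming PyTorch's boolean-mask indexing is the closest match to tf.boolean_mask(): indexing a [batch_size, seq_len, hid_size] tensor with a [batch_size, seq_len] bool tensor selects along the first two dimensions, keeps hid_size, and returns the selected rows already flattened into [num_true_tokens, hid_size]. The tensors below are random stand-ins.

```python
import torch

hidden = torch.randn(2, 4, 3)                  # [batch_size, seq_len, hid_size]
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # [batch_size, seq_len]

# Boolean indexing over the first two dims keeps the last dim:
# the result is [num_true_tokens, hid_size], rows in row-major (batch, then position) order.
selected = hidden[attention_mask.bool()]
print(selected.shape)  # torch.Size([5, 3])

# The same indexing applied to every tensor in a hidden_states tuple
hidden_states = (hidden, hidden * 2)
selected_per_layer = [h[attention_mask.bool()] for h in hidden_states]
```

Because the mask has three True entries in the first row and two in the second, five token vectors remain, each still of length hid_size.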
I have managed to do:
torch.where(mask.unsqueeze(-1) == 1, hidden_states, torch.zeros_like(hidden_states))
but it returns a tensor with the same shape as the input, with zeros where the mask is False. I need to drop those masked-out positions entirely, so that only the first two dimensions change and hid_size stays the same.
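For comparison, a sketch of the torch.where behaviour described above next to torch.masked_select, which does drop the masked positions but returns a 1-D tensor, so hid_size has to be restored with a reshape. The mask is given a trailing dimension of size 1 so it broadcasts over hid_size; the tensors are random stand-ins.

```python
import torch

hid_size = 3
hidden = torch.randn(2, 4, hid_size)           # [batch_size, seq_len, hid_size]
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # [batch_size, seq_len]
keep = attention_mask.bool().unsqueeze(-1)     # [batch_size, seq_len, 1], broadcasts over hid_size

# torch.where keeps the full [batch_size, seq_len, hid_size] shape; masked positions become 0
zeroed = torch.where(keep, hidden, torch.zeros_like(hidden))

# torch.masked_select drops the masked positions but flattens everything to 1-D,
# so the hid_size dimension is restored with view(-1, hid_size)
selected = torch.masked_select(hidden, keep).view(-1, hid_size)
print(zeroed.shape, selected.shape)  # torch.Size([2, 4, 3]) torch.Size([5, 3])
```

Filtering the zeroed tensor for zeros afterwards would be fragile, since a genuine hidden value can also be 0; selecting by the mask directly avoids that.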
I would appreciate any help!
Thanks