I have a hidden_states tuple of tensors from the output of a Hugging Face model, where each tensor has shape:
[batch_size, seq_len, hid_size].
I also have an attention_mask tensor from the batch samples (produced by the tokenizer) with shape:
[batch_size, seq_len].
I want to apply the mask to the hidden states so that I extract only the elements of hidden_states where the mask is True. The extraction should act only on the first two dimensions and leave the hid_size dimension untouched. I am looking for something like tf.boolean_mask().
I plan to flatten the first two dimensions together afterwards.
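One PyTorch counterpart of tf.boolean_mask() is indexing with a bool tensor. A minimal sketch, assuming dummy tensors with made-up sizes in place of the real model output: when the bool mask has the same shape as the first two dimensions of the hidden-state tensor, indexing selects along those two dimensions and flattens them in one step, so the result is [num_true, hid_size].

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, hid_size = 2, 4, 3

# Dummy stand-in for one layer's hidden states: [batch_size, seq_len, hid_size].
hidden = torch.arange(batch_size * seq_len * hid_size, dtype=torch.float32).reshape(
    batch_size, seq_len, hid_size
)

# Dummy attention mask: 1 for real tokens, 0 for padding.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# A bool index shaped like the first two dims consumes both of them and
# returns the selected rows in row-major order: [num_true, hid_size].
selected = hidden[attention_mask.bool()]
print(selected.shape)  # torch.Size([5, 3])
```

Because the first two dimensions are already merged into one, no separate flattening step is needed afterwards.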
So far I have tried:
    torch.where(mask.unsqueeze(-1) == 1, hidden_states, torch.zeros_like(hidden_states))
but it returns a tensor with the same dimensions, with zeros where the mask is False. I need to drop those masked-out positions entirely, while keeping the hid_size dimension unchanged.
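The difference between torch.where and boolean indexing can be seen side by side. A sketch, assuming dummy tensors in place of the real model output: torch.where keeps the [batch_size, seq_len, hid_size] shape and only zeroes the padded positions, while boolean indexing removes them. To handle the whole tuple of per-layer hidden states at once, the sketch stacks it with torch.stack (one possible choice; masking each layer in a loop works too).

```python
import torch

# Hypothetical sizes, chosen only for illustration.
num_layers, batch_size, seq_len, hid_size = 3, 2, 4, 8

# Dummy stand-in for the model's tuple of per-layer hidden states.
hidden_states = tuple(torch.randn(batch_size, seq_len, hid_size) for _ in range(num_layers))
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
mask = attention_mask.bool()

last = hidden_states[-1]

# torch.where: padded positions become 0, but the shape is unchanged.
zeroed = torch.where(mask.unsqueeze(-1), last, torch.zeros_like(last))
print(zeroed.shape)  # torch.Size([2, 4, 8])

# Boolean indexing: padded positions are removed, first two dims are flattened.
kept = last[mask]
print(kept.shape)  # torch.Size([5, 8])

# All layers at once: [num_layers, B, L, H] -> [num_layers, num_true, H].
kept_all = torch.stack(hidden_states)[:, mask]
print(kept_all.shape)  # torch.Size([3, 5, 8])
```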
I would appreciate any help!