Apply 2D attention mask on 3D hidden_states


I have a hidden_states tuple of tensors from a Hugging Face model's output, where each tensor has shape [batch_size, seq_len, hid_size].

I also have the attention_mask tensor from the tokenized batch (produced by the tokenizer), with shape [batch_size, seq_len].

I want to apply the mask to hidden_states so that I extract only the elements where the mask is True. The selection should act on the first two dimensions only and leave the hid_size dimension intact. I am looking for something like tf.boolean_mask().

I plan to flatten the 2 first dimensions together afterwards.

I have managed to do:

```python
torch.where(mask.unsqueeze(-1) == 1, hidden_states, torch.zeros_like(hidden_states))
```

but it returns a tensor with the same shape as hidden_states, with zeros where the mask is False. I need to drop those masked-out positions entirely while keeping the hid_size dimension unchanged.

I would appreciate any help!

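Just use regular boolean indexing, something like hidden_states[mask, :] (or simply hidden_states[mask]).
Note that the tokenizer's attention_mask is an integer 0/1 tensor, so convert it with mask.bool() first; an integer tensor is treated as a list of indices rather than a mask. With a [batch_size, seq_len] boolean mask, the result has shape [num_true, hid_size], so the first two dimensions are already flattened.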
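A minimal sketch of the boolean-indexing approach applied to every layer in the tuple. The toy sizes, the random tensors, and the two-layer tuple are made-up stand-ins for the real model output; only the shapes matter here.

```python
import torch

# Toy sizes standing in for the real model output (assumed values).
batch_size, seq_len, hid_size = 2, 4, 3

# Stand-in for the tuple of per-layer hidden states (two fake layers).
hidden_states = tuple(torch.randn(batch_size, seq_len, hid_size) for _ in range(2))

# Tokenizer-style attention mask: integer 0/1, padding at the end.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Convert to bool so indexing selects positions instead of gathering by index.
mask = attention_mask.bool()

# Boolean indexing over the first two dims keeps hid_size and flattens
# batch and sequence into one dimension of length mask.sum().
selected = [layer[mask] for layer in hidden_states]

print(selected[0].shape)  # torch.Size([5, 3])
```

Rows come out in row-major order: all kept tokens of the first sequence, then those of the second, and so on.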

Alternatively, you can use masked_select: broadcast the mask over the hidden dimension with unsqueeze(-1), then reshape the 1-D result back to [-1, hid_size].
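A minimal sketch of a masked_select route, using made-up toy tensors. This version uses a single masked_select followed by a reshape, which is one way to get the same result as boolean indexing; it is not necessarily the exact chain of calls the reply had in mind.

```python
import torch

# Toy sizes and data (assumed values for illustration).
batch_size, seq_len, hid_size = 2, 4, 3
hidden_states = torch.randn(batch_size, seq_len, hid_size)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# masked_select broadcasts the mask against the input, so add a trailing
# singleton dim to make [B, S] broadcast over [B, S, H].
mask = attention_mask.bool().unsqueeze(-1)  # shape [B, S, 1]

# masked_select always returns a 1-D tensor; reshape it to [num_kept, hid_size].
flat = hidden_states.masked_select(mask).view(-1, hid_size)

print(flat.shape)  # torch.Size([5, 3])
```

The reshape works because masked_select keeps elements in row-major order, so each kept token's hid_size values stay contiguous.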