Input shape of target mask in nn.Transformer

I am trying to build a transformer model with torch.nn.Transformer(), but I have a question.

I want to pass batched input to the transformer, so I ran the code with a target mask (tgt_mask) of shape (batch_size, seq_len, seq_len), but I got the following runtime error:

RuntimeError: The size of the 3D attn_mask is not correct.

The PyTorch documentation says that the shape of tgt_mask should be (T, T), i.e. (seq_len, seq_len), but there is no mention of the batch size, so I am not sure how to pass the mask. What shape should the Transformer's tgt_mask have when the input is batched?
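For reference, here is a minimal sketch of what I tried (the sizes are made up for illustration, and I am assuming batch_first=True):

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model, nhead = 4, 10, 512, 8

model = nn.Transformer(d_model=d_model, nhead=nhead, batch_first=True)
src = torch.rand(batch_size, seq_len, d_model)
tgt = torch.rand(batch_size, seq_len, d_model)

# Causal mask (True = position may NOT attend), one copy per batch entry:
# shape (batch_size, seq_len, seq_len)
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
tgt_mask = causal.unsqueeze(0).expand(batch_size, -1, -1)

out = model(src, tgt, tgt_mask=tgt_mask)
# RuntimeError: The size of the 3D attn_mask is not correct.
```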

Sorry for my terrible English, I'm a junior high school student in Japan :slight_smile:

I ran into the same problem. I looked into the source code and found that when the attention mask (tgt_mask here) is a 3D tensor, num_heads must also be taken into account: instead of (batch_size, seq_len, seq_len), the shape should be (batch_size * num_heads, seq_len, seq_len).
Here is the relevant documentation from the source code:

attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N · num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.
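To make that concrete, here is a minimal sketch (the sizes are illustrative and I am assuming batch_first=True): each per-sample mask has to be repeated once per attention head along the first dimension.

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model, nhead = 4, 10, 512, 8

model = nn.Transformer(d_model=d_model, nhead=nhead, batch_first=True)
src = torch.rand(batch_size, seq_len, d_model)
tgt = torch.rand(batch_size, seq_len, d_model)

# One causal mask per batch entry: (batch_size, seq_len, seq_len)
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
per_sample = causal.unsqueeze(0).expand(batch_size, -1, -1)

# Repeat each sample's mask once per head, keeping one sample's masks
# contiguous: (batch_size * num_heads, seq_len, seq_len)
tgt_mask = per_sample.repeat_interleave(nhead, dim=0)

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([4, 10, 512])
```

As far as I can tell from the source, repeat_interleave gives the right ordering (sample 0's heads first, then sample 1's, and so on), which matches how MultiheadAttention flattens the batch and head dimensions internally; a plain repeat(nhead, 1, 1) would interleave the samples across heads instead.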
