I am trying to use torch.nn.Transformer()
to build a transformer model, but I have a question.
I want to pass batched input to the transformer, so I tried running the code with a target_mask
of shape (batch_size, seq_len, seq_len)
, but I got the following runtime error:
RuntimeError: The size of the 3D attn_mask is not correct.
The PyTorch documentation says that the shape of the target_mask
should be (T, T)
(which means (seq_len, seq_len)
), but there is no mention of the batch size, so I'm not sure how to construct the target_mask. What is the correct shape of the transformer's target_mask when using batched input?
Sorry for my terrible English, I'm a junior high school student in Japan.
I ran into the same problem. I looked into the source code and found that if the attention mask (the target_mask here) is a 3D tensor, num_heads must also be taken into account. So instead of (batch_size, seq_len, seq_len), the shape should be (batch_size * num_heads, seq_len, seq_len).
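Here's a minimal sketch of what that looks like in practice (the sizes are just illustrative): a per-sequence causal mask is built once and then repeated batch_size * nhead times along the first dimension before being passed as tgt_mask.

```python
import torch
import torch.nn as nn

# Illustrative sizes, not from the original post
d_model, nhead = 16, 4
batch_size, src_len, tgt_len = 2, 5, 3

model = nn.Transformer(d_model=d_model, nhead=nhead,
                       num_encoder_layers=1, num_decoder_layers=1,
                       dim_feedforward=32, batch_first=True)

src = torch.rand(batch_size, src_len, d_model)
tgt = torch.rand(batch_size, tgt_len, d_model)

# Causal mask for one sequence: True above the diagonal means
# "not allowed to attend" (no peeking at future positions).
causal = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

# A 3D mask must have shape (batch_size * nhead, tgt_len, tgt_len),
# NOT (batch_size, tgt_len, tgt_len) -- this is what raises
# "The size of the 3D attn_mask is not correct."
tgt_mask = causal.unsqueeze(0).repeat(batch_size * nhead, 1, 1)

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # (batch_size, tgt_len, d_model)
```

Note that if the mask is the same for every batch entry, as here, you can simply pass the 2D (tgt_len, tgt_len) mask instead and let it broadcast; the 3D form is only needed when different entries need different masks.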
Here’s the documentation in the source code:
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.