Input shape of target mask in nn.Transformer

I am trying to use torch.nn.Transformer to build a transformer model, but I have a question.

I want to feed batched input to the transformer, so I tried passing the target mask (the tgt_mask argument) with shape (batch_size, seq_len, seq_len), but I got the following runtime error:

RuntimeError: The size of the 3D attn_mask is not correct.
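A minimal sketch that reproduces this error, assuming the default batch_first=False layout (inputs shaped (seq_len, batch, d_model)) and hypothetical sizes chosen only for illustration. The nn.MultiheadAttention documentation describes a 3D attention mask as (N * num_heads, T, T), so a (batch_size, T, T) mask only passes the shape check when nhead is 1.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch_size, src_len, tgt_len = 4, 10, 7
d_model, nhead = 32, 4

# Default batch_first=False, so inputs are (seq_len, batch, d_model).
model = nn.Transformer(d_model=d_model, nhead=nhead,
                       num_encoder_layers=1, num_decoder_layers=1,
                       dim_feedforward=64)
src = torch.rand(src_len, batch_size, d_model)
tgt = torch.rand(tgt_len, batch_size, d_model)

# One (T, T) float mask per sample, stacked to (batch_size, T, T) as in the question.
# All zeros means "mask nothing"; only the shape matters here.
bad_mask = torch.zeros(batch_size, tgt_len, tgt_len)

try:
    model(src, tgt, tgt_mask=bad_mask)
    error = None
except RuntimeError as exc:
    error = str(exc)

print(error)
```

The exact wording depends on the PyTorch version; newer releases may report the expected shape, (batch_size * nhead, T, T), instead of just saying the size is not correct.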

The PyTorch documentation says the target_mask should have shape (T, T), i.e. (seq_len, seq_len), but it says nothing about the batch dimension, so I'm not sure how to pass it. What shape should the transformer's target mask have when the input is batched?
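A minimal sketch of the two mask shapes that batched input accepts, under the same assumptions (default batch_first=False, hypothetical sizes). A single (T, T) mask is applied to every sequence in the batch. If each sample needs its own mask, the (N, T, T) stack has to be expanded to (N * nhead, T, T); this sketch assumes the first dimension is ordered sample-major (all heads of sample 0, then all heads of sample 1, ...), which is why it uses repeat_interleave rather than repeat.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch_size, src_len, tgt_len = 4, 10, 7
d_model, nhead = 32, 4

model = nn.Transformer(d_model=d_model, nhead=nhead,
                       num_encoder_layers=1, num_decoder_layers=1,
                       dim_feedforward=64)
model.eval()  # disable dropout so the two calls below are comparable

src = torch.rand(src_len, batch_size, d_model)  # (S, N, E)
tgt = torch.rand(tgt_len, batch_size, d_model)  # (T, N, E)

# Option 1: one (T, T) mask, shared by every sequence in the batch.
# Float mask: -inf above the diagonal (future positions), 0 elsewhere.
causal = model.generate_square_subsequent_mask(tgt_len)
with torch.no_grad():
    out_2d = model(src, tgt, tgt_mask=causal)

# Option 2: per-sample masks of shape (N, T, T) are expanded to
# (N * nhead, T, T) by repeating each sample's mask once per head.
per_sample = causal.unsqueeze(0).repeat(batch_size, 1, 1)  # (N, T, T)
per_head = per_sample.repeat_interleave(nhead, dim=0)      # (N * nhead, T, T)
with torch.no_grad():
    out_3d = model(src, tgt, tgt_mask=per_head)

print(per_head.shape)  # torch.Size([16, 7, 7])
print(out_2d.shape)    # torch.Size([7, 4, 32])
```

Because every sample here uses the same causal mask, both calls should produce matching outputs; in practice per_sample would hold a different (T, T) mask for each sequence. For masking padded positions per sample, the separate tgt_key_padding_mask argument, shaped (N, T), may be a simpler fit than a 3D tgt_mask.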

Sorry for my terrible English; I'm a junior high school student in Japan 🙂