I am trying to build a transformer model with torch.nn.Transformer(), but I have a question.
I want to feed batched input to the transformer, so I tried running the code with a target_mask of shape (batch_size, seq_len, seq_len), but I got the following runtime error:
RuntimeError: The size of the 3D attn_mask is not correct.
The PyTorch documentation says that the shape of the target_mask should be (T, T), meaning (seq_len, seq_len), but it does not mention the batch size, so I’m not sure how to pass the target_mask. What shape should the transformer’s target_mask have when the input is batched?
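For reference, here is a minimal sketch of the shapes involved, as far as I can tell from the documentation. The sizes and variable names are made up for illustration, and the argument I call target_mask is passed as tgt_mask. The assumption is that a 2D (T, T) mask is applied to every sequence in the batch, and that a 3D mask is expected to have batch_size * num_heads as its first dimension rather than batch_size alone, which may be why a (batch_size, seq_len, seq_len) mask is rejected.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch_size, src_len, tgt_len = 4, 7, 5
d_model, nhead = 16, 2

model = nn.Transformer(
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=1,
    num_decoder_layers=1,
    dim_feedforward=32,
    dropout=0.0,
)

# Default layout is sequence-first: (seq_len, batch_size, d_model).
src = torch.rand(src_len, batch_size, d_model)
tgt = torch.rand(tgt_len, batch_size, d_model)

# 2D causal mask of shape (T, T): -inf above the diagonal, 0 elsewhere.
# The same mask is used for every sequence in the batch.
causal_mask = torch.triu(
    torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1
)
out_2d = model(src, tgt, tgt_mask=causal_mask)

# 3D mask of shape (batch_size * nhead, T, T): one (T, T) mask per
# sequence and attention head. Here every copy is the same causal mask.
mask_3d = causal_mask.repeat(batch_size * nhead, 1, 1)
out_3d = model(src, tgt, tgt_mask=mask_3d)

print(causal_mask.shape, mask_3d.shape, out_2d.shape, out_3d.shape)
```

If the only per-sequence difference is padding, tgt_key_padding_mask with shape (batch_size, seq_len) looks like the intended way to express it, while tgt_mask stays (T, T).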
Sorry for my terrible English; I’m a junior high school student in Japan.