PyTorch transformer expected input

Hi All,
Let's say I have an input with dimensions [4, 300, 15, 10], which correspond to [batch size, frames, objects, data for each object]. I want to use the transformer encoder layer to perform self-attention over the objects. Since the input is expected to have 3 dimensions, I folded the frames axis into the batch axis.
Digging into the transformer encoder layer, I saw that the self-attention module expects the batch size on the second axis instead of the first. I looked for a flag that swaps the axes but couldn't find one. In addition, when looking at the output weights of the self-attention module, the shape of the weights tensor was [1200, 1200, 15], which looks like the attention was performed over the time axis (the first axis, of size batch size * number of frames).
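
For reference, here is a minimal sketch of what I'm currently doing (shapes taken from above; `d_model=10` and `nhead=2` are just placeholder values, not my real hyperparameters):

```python
import torch
import torch.nn as nn

# [batch, frames, objects, features] as described above
x = torch.randn(4, 300, 15, 10)

# fold the frames axis into the batch axis -> [1200, 15, 10]
x = x.reshape(4 * 300, 15, 10)

# placeholder hyperparameters; d_model must match the last (feature) dim
layer = nn.TransformerEncoderLayer(d_model=10, nhead=2)

# with the default layout the first axis (1200) seems to be treated as the sequence axis
out = layer(x)
print(out.shape)  # torch.Size([1200, 15, 10])
```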

What should I do to perform the attention over the objects? Do I need to explicitly swap the first and second axes using transpose (as in the sketch below), or is it done automatically inside the transformer encoder layer?
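
For clarity, this is the kind of transpose I have in mind (same placeholder values as above; I'm not sure it's the right approach):

```python
import torch
import torch.nn as nn

x = torch.randn(4 * 300, 15, 10)                  # folded input: [batch*frames, objects, features]
layer = nn.TransformerEncoderLayer(d_model=10, nhead=2)

x = x.transpose(0, 1)                             # -> [objects, batch*frames, features]
out = layer(x)                                    # attention should now run across the 15 objects
out = out.transpose(0, 1)                         # back to [batch*frames, objects, features]
```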

Thank you in advance :slight_smile:

@ptrblck maybe you can help me with it?