Multiple batch dimensions

Some of PyTorch’s built-in modules support multiple ‘batch’ dimensions.
For example, the input of the nn.TransformerEncoder module (with the default batch_first=False) has shape (n_words, batch, dim_in).
Although the first dim is not really a batch dim, its size does not change the number of parameters of the model and can vary after the model is instantiated.
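A minimal sketch of that behaviour, assuming a small encoder with arbitrarily chosen sizes (d_model=16, nhead=4, two layers): the same instance accepts different n_words, and its parameter count does not depend on the input length. Depending on the PyTorch version, constructing the encoder may emit a warning about nested tensors; it does not affect the result.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not from the thread).
d_model, nhead, batch = 16, 4, 3

# batch_first defaults to False, so inputs are (n_words, batch, d_model).
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

# The parameter count is fixed at construction time...
n_params = sum(p.numel() for p in encoder.parameters())

# ...yet the same model accepts sequences of different lengths.
shapes = []
for n_words in (5, 12):
    x = torch.randn(n_words, batch, d_model)
    shapes.append(tuple(encoder(x).shape))

print(shapes)  # [(5, 3, 16), (12, 3, 16)]
```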
I want similar functionality for the nn.Linear module. That is, I have input with shape (n_words, batch, dim_in) and I want dimensions 0 and 1 treated as batch dimensions.
The simplest way to achieve this is to flatten the batch dimensions into one, run the default forward, and then view/reshape the output back into the original shape.
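The flatten/forward/reshape workaround described above could look roughly like this. The helper name linear_flattened is hypothetical, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn


def linear_flattened(layer: nn.Linear, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply `layer` after merging all leading dims into one batch dim."""
    *batch_shape, dim_in = x.shape
    # (n_words, batch, dim_in) -> (n_words * batch, dim_in)
    flat = x.reshape(-1, dim_in)
    out = layer(flat)
    # (n_words * batch, dim_out) -> (n_words, batch, dim_out)
    return out.reshape(*batch_shape, layer.out_features)


n_words, batch, dim_in, dim_out = 5, 3, 8, 4
layer = nn.Linear(dim_in, dim_out)
x = torch.randn(n_words, batch, dim_in)
y = linear_flattened(layer, x)
print(y.shape)  # torch.Size([5, 3, 4])
```

reshape is used instead of view so the helper also works on non-contiguous inputs (for example after a transpose), at the cost of a possible copy.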

Is there a better/faster way to do this? Will this even have a performance cost?

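nn.Linear accepts inputs in the shape [batch_size, *, in_features], where * denotes any number of additional dimensions. The layer only acts on the last dimension, so you can pass the (n_words, batch, dim_in) tensor directly.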
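A quick check of that answer, under arbitrarily chosen sizes: calling the layer on the 3-D tensor directly gives the same result as the manual flatten/reshape approach, up to floating-point tolerance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_words, batch, dim_in, dim_out = 5, 3, 8, 4
layer = nn.Linear(dim_in, dim_out)
x = torch.randn(n_words, batch, dim_in)

# Direct call: the weight is applied along the last dim; dims 0 and 1 are kept as-is.
y_direct = layer(x)

# Manual flatten -> forward -> reshape, for comparison.
y_flat = layer(x.reshape(-1, dim_in)).reshape(n_words, batch, dim_out)

print(y_direct.shape)  # torch.Size([5, 3, 4])
same = torch.allclose(y_direct, y_flat, atol=1e-6)
```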


Oh god I didn’t even try. Thanks!