Some of PyTorch’s built-in modules support multiple ‘batch’ dimensions.
For example, the input of the nn.TransformerEncoder module has shape (n_words, batch, dim_in).
Although the first dimension is not really a batch dimension, its size does not affect the number of parameters of the model and can vary after the model has been instantiated.
I want similar functionality for the nn.Linear module. That is, I have input with shape (n_words, batch, dim_in) and I want dimensions 0 and 1 to be treated as batch dimensions.
The simplest way to achieve this is to flatten the batch dimensions into one, run the default forward, and then view/reshape the output back to the original leading dimensions.
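To make the question concrete, here is a minimal sketch of the flatten-and-reshape approach I have in mind. The sizes (n_words=7, batch=4, dim_in=16, dim_out=32) and variable names are made up purely for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
n_words, batch, dim_in, dim_out = 7, 4, 16, 32

linear = nn.Linear(dim_in, dim_out)
x = torch.randn(n_words, batch, dim_in)

# Merge the two leading dimensions into a single batch dimension.
flat = x.reshape(-1, dim_in)                # (n_words * batch, dim_in)

# Run the default forward on the flattened input.
out = linear(flat)                          # (n_words * batch, dim_out)

# Restore the original leading dimensions.
out = out.reshape(n_words, batch, dim_out)  # (n_words, batch, dim_out)
```

I use reshape rather than view here so that the sketch also works if x is not contiguous. As far as I understand, for a contiguous tensor reshape returns a view without copying, so the reshaping itself should be cheap in that case.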
Is there a better/faster way to do this? Will this even have a performance cost?