Hey all, I’m learning how to use PyTorch and am wondering what the best way to do the following is:

In a custom module, I have an intermediate tensor of shape [2, 8, 16], and I’d like to apply a linear layer transform, say 8 -> 4, along the middle dimension, getting out a tensor of shape [2, 4, 16]. The obvious way to do it would be a for loop, but if I run this in a large network with my layer on the GPU, would PyTorch automagically parallelize the loop? Or should I instead extract the weights of the linear layer and formulate it as a single matrix multiply? What’s the most efficient way to do this?
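To make the question concrete, here’s a minimal sketch of the loop version I have in mind, next to a vectorized alternative that relies on `nn.Linear` acting on the last dimension (the shapes match the example above; variable names are just placeholders):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8, 16)   # [batch, features=8, other=16]
linear = nn.Linear(8, 4)    # transform 8 -> 4 along the middle dim

# Loop version: apply the layer to each [2, 8] slice along the last dim,
# then stack the 16 results back into a [2, 4, 16] tensor.
out_loop = torch.stack([linear(x[:, :, i]) for i in range(x.shape[2])], dim=2)

# Vectorized version: move the middle dim to the end so nn.Linear can
# act on it, apply the layer, then move it back.
out_vec = linear(x.transpose(1, 2)).transpose(1, 2)  # shape [2, 4, 16]
```

Both versions compute the same result; the question is whether the transpose-based formulation is the right way to get good GPU performance.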

Thanks for any help!