Doing a Linear Transform over a dimension

Hey All, I’m learning how to use PyTorch and am wondering what the best way to do the following is:

In a custom module, I have an intermediate tensor with a shape like [2, 8, 16]. I’d like to apply a linear transform, say from 8 -> 4, to the middle dimension and get out a tensor of shape [2, 4, 16]. The obvious way would be a for loop, but if I ran this in a large network with my layer on the GPU, would PyTorch automagically parallelize it? Or should I extract the weights of the linear layer and formulate it as a matrix multiply? What’s the most efficient way to do this?

Thanks for any help

You could permute the dimensions so the size-8 dim is last, apply a linear layer, and permute back:

import torch
import torch.nn as nn

x = torch.randn(2, 8, 16)
lin = nn.Linear(8, 4)
y = lin(x.permute(0, 2, 1))   # [2, 16, 8] -> [2, 16, 4]
output = y.permute(0, 2, 1)   # back to [2, 4, 16]

nn.Linear is applied to the last dimension of the input, so any shape (*, 8) works; the docs explain the expected input and output shapes.
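If you do want the explicit matrix-multiply formulation mentioned in the question, here is a minimal sketch using the layer’s own weights via torch.einsum (reusing lin and x from above):

w, b = lin.weight, lin.bias                            # w: [4, 8], b: [4]
out = torch.einsum('oi,bil->bol', w, x) + b[:, None]   # [2, 4, 16]

This should match the permute version up to floating-point rounding; in practice the permute approach is usually just as fast and easier to read.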


You could use Conv1d with kernel_size=1.
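Since nn.Conv1d treats dim 1 as the channel dimension, a 1x1 convolution maps 8 -> 4 in place with no permutes. A minimal sketch, reusing x from above:

conv = nn.Conv1d(in_channels=8, out_channels=4, kernel_size=1)
output = conv(x)   # [2, 8, 16] -> [2, 4, 16]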
