How to use nn.Linear to transform the specified dimension of a tensor?

Hi,

Usually, we use torch.nn.Linear to transform a tensor, for example:

import torch
import torch.nn as nn
x = torch.rand(16,32,64)
W = nn.Linear(64,128)
x = W(x) 

Then we get x with shape [16, 32, 128], because nn.Linear only acts on the last dimension.
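To see why only the last dimension changes, here is a minimal sketch that compares the layer against a manual matmul with the layer's own weight and bias. The leading dimensions (16, 32) are simply treated as batch dimensions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(16, 32, 64)
W = nn.Linear(64, 128)

# nn.Linear computes x @ W.weight.T + W.bias over the last dimension only;
# W.weight has shape (128, 64), so the leading dims (16, 32) pass through.
y_layer = W(x)
y_manual = x @ W.weight.T + W.bias

print(y_layer.shape)  # torch.Size([16, 32, 128])
print(torch.allclose(y_layer, y_manual, atol=1e-6))  # True
```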

Now, the question is, what should I do if I want to transform the second dimension of x? That is, transform x from [16,32,64] to [16,the_size_I_want,64]

Thanks!

You can simply swap the dimensions to move the size-32 dimension to the end before the operation (nn.Linear always acts on the last dimension), and then swap it back to its original position after computing the linear projection.

import torch
import torch.nn as nn

x = torch.rand(16,32,64)
W1 = nn.Linear(32,128)
y = W1(x.permute(0,2,1)).permute(0,2,1)
print(y.shape)
# torch.Size([16, 128, 64])
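The same idea generalizes to any dimension. Below is a minimal sketch with a hypothetical helper, linear_along_dim (not a PyTorch API), that uses torch's movedim to bring the chosen dimension to the end and then put it back.

```python
import torch
import torch.nn as nn

def linear_along_dim(layer: nn.Linear, x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: apply `layer` to dimension `dim` of `x`.

    Moves `dim` to the end (where nn.Linear operates), applies the layer,
    then moves the transformed dimension back to position `dim`.
    """
    out = layer(x.movedim(dim, -1))
    return out.movedim(-1, dim)

x = torch.rand(16, 32, 64)
W1 = nn.Linear(32, 128)

y = linear_along_dim(W1, x, dim=1)
print(y.shape)  # torch.Size([16, 128, 64])

# Should match the permute-based version.
y_permute = W1(x.permute(0, 2, 1)).permute(0, 2, 1)
print(torch.allclose(y, y_permute))  # True
```

For a 3-D tensor and dim=1, movedim(1, -1) produces the same view as permute(0, 2, 1); the helper just saves you from writing out the permutation by hand for higher-rank tensors.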

Alternatively, you can use a trick: a 1d convolution with kernel size 1 is equivalent to nn.Linear applied along the channel dimension (dim 1 of an input shaped (N, C, L)), so no permute is needed:

W2 = nn.Conv1d(32, 128, 1)
z = W2(x)
print(z.shape)
# torch.Size([16, 128, 64])
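A minimal sketch to check the equivalence numerically: copy the Linear parameters into the Conv1d layer and compare outputs. The only adjustment is the extra kernel axis of size 1 on the Conv1d weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(16, 32, 64)
W1 = nn.Linear(32, 128)
W2 = nn.Conv1d(32, 128, 1)

# Conv1d weight shape is (out_channels, in_channels, kernel_size) = (128, 32, 1),
# so the Linear weight (128, 32) only needs a trailing axis of size 1.
with torch.no_grad():
    W2.weight.copy_(W1.weight.unsqueeze(-1))
    W2.bias.copy_(W1.bias)

y = W1(x.permute(0, 2, 1)).permute(0, 2, 1)  # Linear on the moved dimension
z = W2(x)                                    # kernel-size-1 convolution

print(torch.allclose(y, z, atol=1e-5))  # True
```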

W1 and W2 also have the same number of weight and bias parameters; the Conv1d weight just carries an extra trailing kernel dimension of size 1. You can check with W1.weight.shape and W1.bias.shape, and similarly for W2.
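For example, a quick sketch that prints the parameter shapes and totals for both layers:

```python
import torch.nn as nn

W1 = nn.Linear(32, 128)
W2 = nn.Conv1d(32, 128, 1)

print(W1.weight.shape, W1.bias.shape)  # torch.Size([128, 32]) torch.Size([128])
print(W2.weight.shape, W2.bias.shape)  # torch.Size([128, 32, 1]) torch.Size([128])

# Total parameter counts match: 128*32 + 128 = 4224 for both layers.
n1 = sum(p.numel() for p in W1.parameters())
n2 = sum(p.numel() for p in W2.parameters())
print(n1, n2)  # 4224 4224
```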


Thank you! That’s cool!
