nn.Linear dimensions

Hi, I am quite new to PyTorch and am a bit confused about how to do a certain task with nn.Linear.

I have an input tensor of size 2x28, and I want the output of the Linear layer to be 28xs, where s can be any positive integer.

How can this be achieved? Right now, from my understanding, the input of a 2x28 tensor with s=3 results in the output of a 2x3 tensor. How would I make it so that the output is 28x3?

Here’s my basic code so far:

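import torch
import torch.nn as nn

# Input: 2 rows of 28 features each
x = torch.rand(2, 28)
s = 3
print(x.shape)

# Maps the last dimension from 28 features to s features
linear = nn.Linear(28, s)
out = linear(x)

print('out is:', out.shape)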


torch.Size([2, 28])
out is: torch.Size([2, 3])

Thank you!

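So basically, nn.Linear(in_features, out_features) only transforms the last dimension of its input; every dimension before it is treated as a batch dimension and passes through unchanged. With x of shape [2, 28] and nn.Linear(28, s), the 2 is the batch size and the 28 is replaced by s, which is why you get [2, s]. If you want something of shape [28, s], the 28 has to come first as the batch dimension: transpose the input to [28, 2] and define the layer as nn.Linear(2, s). Keep in mind that this is a different operation, since the layer then mixes the 2 values in each column rather than the 28 values in each row.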
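
As a minimal sketch of that idea, assuming it is acceptable for your use case to treat the 28 entries as the batch and the 2 entries as the features (the name x_t is just an illustrative choice):

```python
import torch
import torch.nn as nn

x = torch.rand(2, 28)
s = 3

# Swap the two dimensions so that 28 is the leading (batch) dimension.
x_t = x.t()  # shape [28, 2]

# in_features must now match the last dimension of x_t, which is 2.
linear = nn.Linear(2, s)
out = linear(x_t)

print('x_t is:', x_t.shape)
print('out is:', out.shape)
```

The layer's weight here has shape [s, 2], so it learns far fewer parameters than nn.Linear(28, s); whether that is what you want depends on what the 2 and the 28 mean in your data.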