Model outputs a single prediction instead of predictions for whole batch

I defined a simple convolutional model to try something.
Unfortunately, I can’t figure out why the following only outputs a single prediction instead of a vector of predictions with size batch_size.

class ShapeNet(LightningModule):
  """Small CNN classifier: three conv blocks followed by a linear head.

  Expects input of shape (batch_size, in_channels, height, width) and
  returns per-sample class scores of shape (batch_size, n_classes).
  """

  def __init__(self, in_channels=1, hidden_dim=16, n_classes=3, lr=0.0001):
    # Initialize the parent Module first so attribute assignment goes
    # through fully set-up nn.Module machinery.
    super(ShapeNet, self).__init__()

    self.lr = lr

    self.convolution = nn.Sequential(OrderedDict([
        ("conv1", self.conv_block(in_channels, hidden_dim)),
        ("conv2", self.conv_block(hidden_dim, hidden_dim * 2)),
        ("final_conv", self.conv_block(hidden_dim * 2, 1, final_layer=True)),
    ]))

    self.fc = nn.Sequential(OrderedDict([
        # in_features=3136 is hard-coded for one specific input resolution
        # (flattened final-conv output) — presumably 1 x 56 x 56; confirm
        # against the training data.
        ("fc", nn.Linear(in_features=3136, out_features=n_classes)),
        # NOTE(review): key says "softmax" but the layer is a Sigmoid.
        # For mutually exclusive classes, nn.Softmax(dim=1) — or raw
        # logits with CrossEntropyLoss — is likely intended; the key is
        # kept as-is so existing state_dicts still load.
        ("softmax", nn.Sigmoid())
    ]))

  def conv_block(self, in_channels, out_channels, kernel_size=3, stride=2, final_layer=False):
    """Build one conv stage.

    Intermediate stages are Conv2d -> BatchNorm2d -> MaxPool2d -> LeakyReLU;
    the final stage is a bare Conv2d.
    """
    if not final_layer:
        return nn.Sequential(OrderedDict([
            ("Conv2d", nn.Conv2d(in_channels, out_channels, kernel_size, stride)),
            ("BatchNorm", nn.BatchNorm2d(out_channels)),
            ("MaxPool", nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))),
            ("LeakyRelu", nn.LeakyReLU(0.2, inplace=True))
        ]))
    else:
        return nn.Sequential(OrderedDict([
            ("final", nn.Conv2d(in_channels, out_channels, kernel_size, stride))
        ]))

  def forward(self, x):
    """Compute class scores for a batch.

    Args:
        x: tensor of shape (batch_size, in_channels, height, width).

    Returns:
        Tensor of shape (batch_size, n_classes).
    """
    x = self.convolution(x.float())
    # BUG FIX: flatten starting at dim 1 so the batch dimension survives.
    # The default start_dim=0 collapses the whole batch into one vector,
    # producing a single (n_classes,) prediction instead of one per sample.
    x = flatten(x, start_dim=1)
    x = self.fc(x)
    return x

Would someone be able to tell me where I’m going wrong here?

When I call model(x) with x of shape (batch_size, 1, width, height),
I'd expect a tensor of predictions of shape (batch_size, n_classes).
But I only get a tensor of shape (n_classes,).

thanks in advance!

I figured it out. Coming from TensorFlow, I had assumed that flatten would keep the batch dimension, but in PyTorch I needed to pass the argument start_dim=1.