RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x64 and 1024x1)

I would like to output two information based on the below unet model, a classification at image level using a linear layer and a segmentated image, but the linear layer does’nt work and I got this error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x64 and 1024x1)


class UNet(nn.Module):

    def __init__(self, num_classes, in_channels=3, depth=5, start_filts=64, up_mode='transpose', merge_mode='concat'):
        super(UNet, self).__init__()

        self.up_mode = up_mode

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i)
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)

        # create the decoder pathway and add to a list
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,merge_mode=merge_mode)

        self.conv_final = conv1x1(outs, 1, 1)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)
        self.linear = nn.Linear(1024, 1)
        self.activation = nn.ReLU()

    def forward(self, x):
        encoder_outs = []
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)

            x_ = torch.sum(self.activation(x), [2,3])
            out_linear = self.linear(x_)
    #def decoder(self)     
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)
        x = self.conv_final(x)
        return x, out_linear

You are trying to multiply a 2 by 64 to 1024 by 1. If you change the 1024 of your linear layer to 64 it will work