Embedding dimension mismatch in CVAE

I am trying to implement a conditional VAE but I’m getting the following error from the encoder embeddings:
RuntimeError: Tensors must have same number of dimensions: got 4 and 2
Here’s how I’ve implemented the encoder and decoder networks:

import torch
import torch.nn as nn

class Unet_Encoder(nn.Module):

    def __init__(self, num_classes, in_channels=3):
        super(Unet_Encoder, self).__init__()
        self.embed = nn.Embedding(num_classes, 256*256)
        self.down_1 = Unet_DownBlock(in_channels+1, 32, normalize=False)
        self.down_2 = Unet_DownBlock(32, 64)
        self.down_3 = Unet_DownBlock(64, 128)
        self.down_4 = Unet_DownBlock(128, 256)
        self.down_5 = Unet_DownBlock(256, 256)
        self.linear_encoder = nn.Linear(256 * 8 * 8, 512)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, labels):
        embedding = self.embed(labels)
        embedding = embedding.view(labels.shape[0], 1, 256, 256)
        x = torch.cat([x, embedding], dim=1)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.down_3(x)
        x = self.down_4(x)
        x = self.down_5(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear_encoder(x)
        x = self.dropout(x)
        return x
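For reference, this is the shape flow I expect at the torch.cat in the encoder's forward (just a standalone sanity-check sketch of my assumptions, using a 3-channel 256x256 batch of size 2; it is not part of the model):

import torch
import torch.nn as nn

embed = nn.Embedding(11, 256 * 256)           # same embedding as in Unet_Encoder
labels = torch.tensor([3, 7])                 # two integer class labels
x = torch.randn(2, 3, 256, 256)               # two 3-channel 256x256 images

emb = embed(labels)                           # -> [2, 65536]
emb = emb.view(labels.shape[0], 1, 256, 256)  # -> [2, 1, 256, 256]
out = torch.cat([x, emb], dim=1)              # -> [2, 4, 256, 256]
print(out.shape)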

class Unet_Decoder(nn.Module):

    def __init__(self, num_classes, out_channels=3):
        super(Unet_Decoder, self).__init__()
        self.embed = nn.Embedding(num_classes, 128)
        self.linear_1 = nn.Linear(256, 8*8*256)
        self.dropout = nn.Dropout(0.5)
        self.deconv_1 = Unet_UpBlock(256, 256)
        self.deconv_2 = Unet_UpBlock(256, 128)
        self.deconv_3 = Unet_UpBlock(128, 64)
        self.deconv_4 = Unet_UpBlock(64, 32)
        self.final_image = nn.Sequential(
            nn.ConvTranspose2d(32, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x, labels):
        embedding = self.embed(labels)
        x = torch.cat([x, embedding], dim=1)
        x = self.linear_1(x)
        x = x.view(-1, 256, 8, 8)
        x = self.dropout(x)
        x = self.deconv_1(x)
        x = self.deconv_2(x)
        x = self.deconv_3(x)
        x = self.deconv_4(x)
        x = self.final_image(x)
        return x
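Similarly, here is the shape flow I expect at the torch.cat in the decoder's forward, assuming the latent vector has 128 features so that the concatenated tensor matches linear_1's 256 input features (again only a sketch of my assumptions):

import torch
import torch.nn as nn

embed = nn.Embedding(11, 128)        # same embedding as in Unet_Decoder
labels = torch.tensor([3, 7])        # two integer class labels
z = torch.randn(2, 128)              # assumed latent vectors with 128 features each

emb = embed(labels)                  # -> [2, 128]
out = torch.cat([z, emb], dim=1)     # -> [2, 256], matching nn.Linear(256, 8*8*256)
print(out.shape)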

Note that when I test the entire model on a dummy tensor and label, it works just fine:

x = torch.randn((1, 1, 256, 256))
model = VAE_Model(num_classes=11)
label = torch.tensor([5])
out = model(x, label)
print(label.shape)
print(out.shape)

Output:
torch.Size([1])
torch.Size([1, 1, 256, 256])

The issue is that when I try to train it on the MNIST dataset, I get the error above. I printed the shape of the labels and it shows torch.Size([12]), which is correct because my batch size is 12. Where am I going wrong?
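In case the data loading matters: a sketch of a standard torchvision MNIST loader with my batch size (my actual pipeline may use different transforms) produces batches like this, which is where the [12]-shaped labels come from:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_set = datasets.MNIST(root="data", train=True, download=True,
                           transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=12, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([12, 1, 28, 28])
print(labels.shape)   # torch.Size([12])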