I am trying to implement a conditional VAE but I’m getting the following error from the encoder embeddings:
RuntimeError: Tensors must have same number of dimensions: got 4 and 2
Here’s how I’ve implemented the encoder and decoder networks:
class Unet_Encoder(nn.Module):
def __init__(self, num_classes, in_channels=3):
super(Unet_Encoder, self).__init__()
self.embed = nn.Embedding(num_classes, 256*256)
self.down_1 = Unet_DownBlock(in_channels+1, 32, normalize=False)
self.down_2 = Unet_DownBlock(32, 64)
self.down_3 = Unet_DownBlock(64, 128)
self.down_4 = Unet_DownBlock(128, 256)
self.down_5 = Unet_DownBlock(256, 256)
self.linear_encoder = nn.Linear(256 * 8 * 8, 512)
self.dropout = nn.Dropout(0.5)
def forward(self, x, labels):
embedding = self.embed(labels)
embedding = embedding.view(labels.shape[0], 1, 256, 256)
x = torch.cat([x, embedding], dim=1)
x = self.down_1(x)
x = self.down_2(x)
x = self.down_3(x)
x = self.down_4(x)
x = self.down_5(x)
x = torch.flatten(x, start_dim=1)
x = self.linear_encoder(x)
x = self.dropout(x)
return x
class Unet_Decoder(nn.Module):
def __init__(self, num_classes, out_channels=3):
super(Unet_Decoder, self).__init__()
self. embed = nn.Embedding(num_classes, 128)
self.linear_1 = nn.Linear(256, 8*8*256)
self.dropout = nn.Dropout(0.5)
self.deconv_1 = Unet_UpBlock(256, 256)
self.deconv_2 = Unet_UpBlock(256, 128)
self.deconv_3 = Unet_UpBlock(128, 64)
self.deconv_4 = Unet_UpBlock(64, 32)
self.final_image = nn.Sequential(*[nn.ConvTranspose2d(32, out_channels,
kernel_size=4, stride=2,
padding=1), nn.Tanh()])
def forward(self, x, labels):
embedding = self.embed(labels)
x = torch.cat([x, embedding], dim=1)
x = self.linear_1(x)
x = x.view(-1, 256, 8, 8)
x = self.dropout(x)
x = self.deconv_1(x)
x = self.deconv_2(x)
x = self.deconv_3(x)
x = self.deconv_4(x)
x = self.final_image(x)
return x
It is to be noted that when I’m testing the entire model on a dummy tensor and label it works just fine:
x = torch.randn((1,1,256,256))
model = VAE_Model(num_classes=11)
label = torch.tensor([5])
out = model(x, label)
print(label.shape)
print(out.shape)
Output:
torch.Size([1])
torch.Size([1, 1, 256, 256])
The issue is when I’m trying to train it on MNIST dataset I’m getting the said error.
I tried printing the shape of the labels and it is showing torch.Size([12]) which is correct because my batch size is 12. Where am I going wrong?