import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
class Encoder(nn.Module):
    def __init__(self, latent_dims):
        # Define all the layers that the model will use.
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten(start_dim=1)
        # After three 2x2 poolings a 64x64 input shrinks to 64 channels of 8x8,
        # i.e. 64 * 8 * 8 = 4096 features per sample.
        self.linear1 = nn.Linear(4096, latent_dims)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.batch_norm1(x)
        x = self.pool(x)                 # 64x64 -> 32x32
        x = F.relu(self.conv2(x))
        x = self.batch_norm2(x)
        x = self.pool(x)                 # 32x32 -> 16x16
        x = F.relu(self.conv3(x))
        x = self.batch_norm3(x)
        x = self.pool(x)                 # 16x16 -> 8x8
        x = self.flatten(x)              # (batch, 4096)
        # Project to the latent space; dim=1 normalises each sample's code.
        return F.softmax(self.linear1(x), dim=1)
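To sanity-check the 4096-feature assumption behind linear1, one can trace a dummy batch through the encoder; the latent size of 128 here is only an illustrative choice:

import torch

enc = Encoder(latent_dims=128)
x = torch.randn(4, 1, 64, 64)   # batch of 4 single-channel 64x64 images
print(enc(x).shape)             # torch.Size([4, 128])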
class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        # The unflattened feature map is 128 channels of 8x8, so the linear
        # layer must produce 128 * 8 * 8 = 8192 features per sample.
        self.linear1 = nn.Linear(latent_dims, 8192)
        self.unflatten = nn.Unflatten(1, (128, 8, 8))
        # Each transposed convolution doubles the spatial size: 8 -> 16 -> 32 -> 64.
        self.conv_t_2d_1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.conv_t_2d_2 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.conv_t_2d_3 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1)
    def forward(self, z):
        z = F.relu(self.linear1(z))       # (batch, 8192)
        z = self.unflatten(z)             # (batch, 128, 8, 8)
        z = F.relu(self.conv_t_2d_1(z))   # (batch, 64, 16, 16)
        z = F.relu(self.conv_t_2d_2(z))   # (batch, 32, 32, 32)
        z = self.conv_t_2d_3(z)           # (batch, 1, 64, 64)
        return torch.sigmoid(z)
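Each ConvTranspose2d above uses kernel 3, stride 2, padding 1 and output_padding 1, which exactly doubles the spatial size: H_out = (H_in - 1) * 2 - 2 * 1 + 3 + 1 = 2 * H_in. A quick standalone check of that arithmetic:

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
print(up(torch.randn(1, 128, 8, 8)).shape)  # torch.Size([1, 64, 16, 16])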
class Autoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)
    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
autoencoder = Autoencoder(latent_dims=128)  # latent size of 128 is illustrative
summary(autoencoder, input_size=(1, 64, 64))
As originally written, the model only ran with a batch size of 2. The culprit was the hard-coded z.view(1, 8192) in the decoder: linear1 emitted 4096 features per sample, so a batch of size N flattened to N * 4096 elements, and only N = 2 gave the 8192 elements that the view and the (128, 8, 8) unflatten required. Sizing linear1 to 8192 = 128 * 8 * 8 and dropping the view, as above, removes the batch-size dependence.
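A quick way to confirm the fix is to push a few different batch sizes through the full model; the shapes assume 64x64 single-channel inputs and the same illustrative latent size of 128:

import torch

model = Autoencoder(latent_dims=128)
model.eval()  # shape check only; avoids updating batch-norm running stats
for batch_size in (1, 2, 8):
    x = torch.randn(batch_size, 1, 64, 64)
    out = model(x)
    assert out.shape == x.shape  # reconstruction matches the input shape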