class Encoder(nn.Module):
    """Encode an image batch into a sampled latent code.

    Three ``Conv2dBlock`` layers (1 -> 64 -> 128 -> 512 channels, batchnorm +
    ReLU) extract features. Two ``LinearBlock`` heads then predict ``mu`` and
    ``logvar``, and ``reparameterization(mu, logvar)`` produces the returned
    code ``z``. ``reparameterization`` is defined elsewhere in the file;
    presumably it samples ``mu + exp(logvar / 2) * eps`` (TODO confirm).

    Args:
        latent_dim: Width of ``mu``/``logvar``. When omitted, it defaults to
            the module-level ``opt.latent_dim``, so ``Encoder()`` keeps working.
    """

    def __init__(self, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = opt.latent_dim
        self.model = nn.Sequential(
            Conv2dBlock(1, 64, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
            Conv2dBlock(64, 128, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
            Conv2dBlock(128, 512, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
        )
        self.mu = LinearBlock(512, latent_dim)
        self.logvar = LinearBlock(512, latent_dim)

    def forward(self, img):
        """Return the latent code ``z`` sampled from the predicted ``mu``/``logvar``."""
        x = self.model(img)
        # The linear heads take 512 input features, while a conv stack
        # presumably yields a (N, 512, H, W) feature map (TODO confirm against
        # Conv2dBlock). Global-average-pool the spatial dims and flatten so
        # the heads receive (N, 512).
        if x.dim() > 2:
            x = torch.flatten(nn.functional.adaptive_avg_pool2d(x, 1), 1)
        mu = self.mu(x)
        logvar = self.logvar(x)
        return reparameterization(mu, logvar)
class Decoder(nn.Module):
    """Decode latent codes back into images.

    An MLP maps ``latent_dim -> 512 -> 512 -> prod(output_shape)``. It has a
    LeakyReLU(0.2) after each hidden layer and a BatchNorm1d before the second
    activation. A final ``Tanh`` bounds the outputs to [-1, 1]. The flat
    output is reshaped to ``(N, *output_shape)``.

    Args:
        latent_dim: Width of the input code. Defaults to ``opt.latent_dim``.
        output_shape: Per-sample output shape, e.g. ``(C, H, W)``. Defaults to
            the module-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, output_shape=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if output_shape is None:
            output_shape = img_shape
        # Keep the shape used to size the last layer, so that forward()
        # reshapes with exactly the same dimensions.
        self.img_shape = tuple(output_shape)
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape (N, latent_dim) to images of shape (N, *img_shape)."""
        img_flat = self.model(z)  # (N, prod(img_shape))
        return img_flat.view(img_flat.shape[0], *self.img_shape)
class Discriminator(nn.Module):
    """Score latent codes with a probability in (0, 1).

    The input width is ``latent_dim``, so this discriminator operates on
    latent vectors, not images. It is an MLP ``latent_dim -> 512 -> 256 -> 1``
    with LeakyReLU(0.2) activations and a final ``Sigmoid``. That output pairs
    with ``nn.BCELoss``, which expects probabilities.

    Args:
        latent_dim: Width of the input code. Defaults to ``opt.latent_dim``.
    """

    def __init__(self, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = opt.latent_dim
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Return validity scores of shape (N, 1) for codes ``z`` of shape (N, latent_dim)."""
        return self.model(z)
# Loss functions: binary cross-entropy for the adversarial objective (the
# Discriminator ends in a Sigmoid, so it emits probabilities as BCELoss
# expects) and a pixel-wise L1 loss.
adversarial_loss = torch.nn.BCELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize the encoder, decoder, and discriminator.
encoder = Encoder()
decoder = Decoder()
discriminator = Discriminator()

if cuda:
    encoder.cuda()
    decoder.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    pixelwise_loss.cuda()