class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to a sampled latent code.

    Three ``Conv2dBlock`` stages (1 -> 64 -> 128 -> 512 channels, each with
    batch-norm and ReLU) extract features.  The feature maps are averaged over
    their spatial dimensions to one 512-vector per sample, which two
    ``LinearBlock`` heads turn into ``mu`` and ``logvar``; ``z`` is then drawn
    via ``reparameterization(mu, logvar)``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            Conv2dBlock(1, 64, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
            Conv2dBlock(64, 128, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
            Conv2dBlock(128, 512, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
        )
        # Both heads consume exactly 512 input features (one per conv channel).
        self.mu = LinearBlock(512, opt.latent_dim)
        self.logvar = LinearBlock(512, opt.latent_dim)

    def forward(self, img):
        """Encode ``img`` and return a latent sample ``z``.

        Args:
            img: image batch; presumably (batch, 1, H, W) to match the first
                conv's single input channel — TODO confirm against the loader.

        Returns:
            The output of ``reparameterization(mu, logvar)``.
        """
        features = self.model(img)
        # NOTE(review): assumes Conv2dBlock returns (batch, C, H, W) like
        # nn.Conv2d and LinearBlock wraps nn.Linear (which only acts on the last
        # dim).  Passing the 4-D map straight to the heads would feed them W'
        # features instead of 512, so collapse the spatial dims first:
        # (batch, 512, H', W') -> (batch, 512).  No parameters are added, so
        # saved state_dicts stay compatible.
        features = features.mean(dim=(2, 3))
        mu = self.mu(features)
        logvar = self.logvar(features)
        return reparameterization(mu, logvar)

class Decoder(nn.Module):
    """Fully connected decoder that turns latent codes back into images.

    A two-hidden-layer MLP (512 units, LeakyReLU(0.2), batch-norm on the second
    hidden layer) produces ``prod(img_shape)`` values squashed by ``tanh``,
    which are then reshaped to ``(batch, *img_shape)``.
    """

    def __init__(self):
        super().__init__()
        n_pixels = int(np.prod(img_shape))
        width = 512
        self.model = nn.Sequential(
            nn.Linear(opt.latent_dim, width),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(width, n_pixels),
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode a latent batch ``z`` of shape (batch, latent_dim) into images.

        The earlier debug notes recorded (256, 10) -> (256, 1024) ->
        (256, 1, 32, 32), i.e. latent_dim=10 and img_shape=(1, 32, 32) in that
        run.
        """
        flat = self.model(z)
        batch_size = flat.size(0)
        return flat.view(batch_size, *img_shape)

class Discriminator(nn.Module):
    """MLP that maps a batch of latent codes to one sigmoid score per sample.

    Architecture: latent_dim -> 512 -> 256 -> 1 with LeakyReLU(0.2) between
    linear layers and a final sigmoid, so each output lies in (0, 1).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(opt.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return the (batch, 1) sigmoid scores for latent batch ``z``."""
        return self.model(z)

# Loss functions: binary cross-entropy on the discriminator's sigmoid scores,
# and an L1 loss kept under the name `pixelwise_loss`.
adversarial_loss = torch.nn.BCELoss()
pixelwise_loss = torch.nn.L1Loss()

# Build the networks (encoder + decoder act as the generator).  The tuple is
# evaluated left to right, so weights are initialised in the same order.
encoder, decoder, discriminator = Encoder(), Decoder(), Discriminator()

# Move the networks and the loss modules to the GPU when `cuda` is enabled.
if cuda:
    for _module in (encoder, decoder, discriminator, adversarial_loss, pixelwise_loss):
        _module.cuda()
    del _module  # don't leave the loop variable behind as a module global