I want to replace the encoder part of the AAE with convolutional layers, but now it doesn't work anymore. Can you help me? Thanks a lot!

class Encoder(nn.Module):
    """Convolutional encoder of the adversarial autoencoder.

    Maps an image batch to latent codes ``z`` of size ``opt.latent_dim``.
    It predicts a mean and a log-variance and samples with
    ``reparameterization`` (defined elsewhere in this file).
    """

    def __init__(self):
        # Must be ``__init__``: a method named ``init`` is never called by
        # Python, so no layers would be registered.
        super(Encoder, self).__init__()

        # NOTE(review): assuming Conv2dBlock wraps a standard nn.Conv2d, then
        # kernel 2 / stride 1 / padding 1 grows each spatial dim by 1 per block
        # (e.g. 32 -> 33 -> 34 -> 35) — confirm against Conv2dBlock.
        self.model = nn.Sequential(
            Conv2dBlock(1, 64, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
            Conv2dBlock(64, 128, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
            Conv2dBlock(128, 512, (2, 2), stride=1, padding=1, norm_fn='batchnorm', acti_fn='relu'),
            # The conv stack yields a 4-D (N, 512, H, W) feature map, but the
            # linear heads below expect (N, 512). Pool each channel to 1x1 and
            # flatten so the shapes line up regardless of the input size.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.mu = LinearBlock(512, opt.latent_dim)
        self.logvar = LinearBlock(512, opt.latent_dim)

    def forward(self, img):
        """Encode ``img`` into sampled latent codes ``z``.

        Assumes ``img`` is (N, 1, H, W): the first Conv2dBlock is built with
        1 input channel — TODO confirm against the data loader.
        """
        x = self.model(img)  # (N, 512) after pooling + flatten
        mu = self.mu(x)
        logvar = self.logvar(x)
        z = reparameterization(mu, logvar)
        return z

class Decoder(nn.Module):
    """MLP decoder: maps latent codes back to images of shape ``img_shape``."""

    def __init__(self):
        # Must be ``__init__`` so PyTorch actually constructs the layers.
        super(Decoder, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(opt.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2, inplace=True),
            # One output unit per pixel of the target image.
            nn.Linear(512, int(np.prod(img_shape))),
        )

    def forward(self, z):
        """Decode latent codes ``z`` of shape (N, latent_dim) into images.

        Example shapes from the original debug output:
        z (256, 10) -> flat (256, 1024) -> img (256, 1, 32, 32).
        """
        img_flat = self.model(z)
        # Restore the image dimensions while keeping the batch dimension.
        img = img_flat.view(img_flat.shape[0], *img_shape)
        return img

class Discriminator(nn.Module):
    """MLP discriminator over latent codes, returning one score per sample.

    Presumably trained to tell prior samples from encoder outputs (the
    training loop is not shown here).
    """

    def __init__(self):
        # Must be ``__init__`` so PyTorch actually constructs the layers.
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(opt.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            # The adversarial loss in this file is nn.BCELoss, which expects
            # probabilities in [0, 1], not raw logits.
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Return validity scores in [0, 1] with shape (N, 1) for codes ``z``."""
        validity = self.model(z)
        return validity

# Loss functions: binary cross-entropy (adversarial) and L1 (pixel-wise).
adversarial_loss = torch.nn.BCELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize encoder, decoder and discriminator
encoder = Encoder()
decoder = Decoder()
discriminator = Discriminator()

if cuda:
    # NOTE(review): the body of this branch was cut off in the paste; moving
    # all networks and loss modules to the GPU is the usual intent — confirm.
    encoder.cuda()
    decoder.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    pixelwise_loss.cuda()

What exactly are you changing and what is not working?
Could you post the error message you are seeing?

Also, you can post code snippets by wrapping them in three backticks (```), which makes debugging easier. 😉

Thank you for your reply. I have solved this problem, but I still have a small one left: I have two loss formulas that I need to implement in PyTorch, and I have no idea how to do it. Can you help me? Thank you very much.