AutoEncoder Image Reconstruction

I was training a Auto Encoder Image Reconstruction model for finding the anomaly in the surveillance images, I used EfficientNetv2-s as Encoder and Se Block + Residual Block + Transpose Convolution Layers containing Custom Decoder.

But what is the model is unable to train, the loss is stagnant from the 5th epoch to 45th epoch…

Do i miss anything That i need to do while training Auto Encoder model

Visualization Report : Visualization report : Weights & Biases

Codes :
Encoder

class Encoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = models.efficientnet_b3(include_top=False, pretrained=False)
        self.model.classifier = nn.Identity()
        
    def forward(self, x):
        return self.model(x)

# Encoder
class EfficientNetv2Encoder(pl.LightningModule):
    def __init__(self):
        super(EfficientNetv2Encoder, self).__init__()
        self.file_path = utils.ROOT_PATH + '/weights/EfficientNetv2Encoder'
        self.example_input_array = torch.rand(1, 3, 256, 256)
        self.example_output_array = torch.rand(1, 1280)
        self.save_hyperparameters()
        self.model = Encoder()
        try:
            self.model = torch.load(utils.ROOT_PATH + '/weights/EfficientNetv2Encoder.pt')
            print("Encoder Model Found")
        except Exception as e:
            self.model = models.efficientnet_v2_s(include_top=False, weights='EfficientNet_V2_S_Weights.DEFAULT')
            self.model.classifier = nn.Identity()
            torch.save(self.model, utils.ROOT_PATH + '/weights/EfficientNetv2Encoder.pt')

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

Decoder:


class SEAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
    
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        # Initial representation
        self.fc = nn.Linear(1280, 4*4*1024)
        self.bn1d = nn.BatchNorm1d(4*4*1024)
        self.gelu = nn.GELU()

        # Decoder layers
        self.conv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()

        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU()

        self.conv6 = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn6 = nn.BatchNorm2d(32)
        self.relu6 = nn.ReLU()

        # Residual blocks with SE attention
        self.res2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.Sigmoid(),
            SEAttention(64),
            nn.ReLU()
        )

        self.res1 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.Sigmoid(),
            SEAttention(256),
            nn.ReLU()
        )
        
        self.conv7 = nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc(x)
        x = self.bn1d(x)
        x = self.gelu(x)
        x = x.view(-1, 1024, 4, 4)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.res1(x) + x


        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)

        x = self.res2(x) + x

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)

        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu6(x)

        x = self.conv7(x)
        x = self.sigmoid(x)

        return x
    

class EfficientNetv2Decoder(pl.LightningModule):
    def __init__(self):
        super(EfficientNetv2Decoder, self).__init__()
        self.model = Decoder()
        try:
            self.model = torch.load(utils.ROOT_PATH + '/weights/EfficientNetv2DecoderLarge.pt')
            print("Decoder Weights Found")
        except Exception as e:
            torch.save(self.model, utils.ROOT_PATH + '/weights/EfficientNetv2DecoderLarge.pt')

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss
    

Auto Encoder:

class AutoEncoder(pl.LightningModule):
    def __init__(self, 
                    ) -> None:
        super(AutoEncoder, self).__init__()
        self.example_input_array = torch.zeros(1, 3, 256, 256)
        self.save_hyperparameters()
        self.encoder = EfficientNetv2Encoder()
        self.decoder = EfficientNetv2Decoder()

    def forward(self, x):
        try:
            x = self.encoder(x)
            x = self.decoder(x)
        except Exception as e:
            print(e, "Error!")
            x = torch.rand(128, 3, 256, 256)
            x = self.encoder(x)
            x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(1), x.size(2), x.size(3), x.size(4)).half()
        y = y.view(y.size(1), y.size(2), y.size(3), y.size(4)).half()
        batch_size = x.shape[0]
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log('train_loss', loss)
        return {"loss" : loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]