Variational Auto Encoder Not able to reconstruct the Surveillance image

I am using CoAtNet with 55M parameter as Encoder and Custom Decoder of 30 M parameter to reconstruct the Surveillance image so as to the get the feature rich latent vector. I tried EfficientNet ResNet all kind of version within 125M parameters and its same for all the models. The model is dropping from 2.2 cross entropy loss to 1.2 , 1.1, 1.0 based on the model complexity and oscillating in it.

Visualization for recent model :

For Every model and the data , image is reconstructed similarly and no further improvement after 2K steps.

UCF-Crime dataset :- from all the video, i extract 4/8 images and send it as a batch and accumulate gradiets for 64/32 batches which implies effective batch_size of 256 with different scenarios.

Every dataloader send it as 1 batch of size (1, 4, 3, 256, 256) and in forward function this is squezzed to represent (4, 3, 256,256) as 4 batches.

MAE loss & KLD loss were used to train … KLD loss is decreasing and works fine. But reconstruction loss is fully constant. And the image generated is very worst.(see in wandb visualization)

Note : i tried MAE, BCE loss, MSE loss & 1e-3, 1e-4, 1e-5 and 1e-4 gives best results for MAE loss, paramters from 20M - 150M.

Encoder Model:

class EncoderCoAtNet(pl.LightningModule):
    def __init__(self, weights_path="/weights/EncoderCoAtNet", 
                 num_blocks=[2, 2, 6, 14, 2], channels=[128, 128, 256, 512, 1024]):
        super(EncoderCoAtNet, self).__init__()
        self.model = CoAtNet(num_blocks=num_blocks, channels=channels)
        self.weights_path = weights_path
        self.example_input_array = torch.randn(1, 3, 256, 256)
        self.example_output_array = torch.randn(1, 1024)
        self.save_hyperparameters()
        self.best_val_loss = None
        try:
            self.model = torch.load(utils.ROOT_PATH + self.weights_path + '.pt')
        except FileNotFoundError:
            torch.save(self.model, utils.ROOT_PATH + self.weights_path + '.pt')
        
    def forward(self, x):
        x = self.model(x)
        return x

Decoder Model :

class SEAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        # Initial representation
        self.fc = nn.Linear(1024, 4*4*1024)
        self.bn1d = nn.BatchNorm1d(4*4*1024)
        self.gelu = nn.GELU()

        # Decoder layers
        self.conv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()

        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU()

        self.conv6 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn6 = nn.BatchNorm2d(16)
        self.relu6 = nn.ReLU()

        # Residual blocks with SE attention
        self.res2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.Sigmoid(),
            SEAttention(64),
            nn.ReLU()
        )

        self.res1 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.Sigmoid(),
            SEAttention(256),
            nn.ReLU()
        )

        self.dropout = nn.Dropout(0.25)
        
        self.conv7 = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.fc(x)
        x = self.bn1d(x)
        x = self.dropout(x)
        x = self.gelu(x)
        x = x.view(-1, 1024, 4, 4)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.relu2(x)

        x = self.res1(x) + x


        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = self.relu4(x)

        x = self.res2(x) + x

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)

        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu6(x)

        x = self.conv7(x)
        x = self.tanh(x)

        return x


class EfficientNetv2Decoder(pl.LightningModule):
    def __init__(self):
        super(EfficientNetv2Decoder, self).__init__()
        self.model = Decoder()
        try:
            self.model = torch.load(utils.ROOT_PATH + '/weights/EfficientNetv2DecoderLarge.pt')
            print("Decoder Weights Found")
        except Exception as e:
            torch.save(self.model, utils.ROOT_PATH + '/weights/EfficientNetv2DecoderLarge.pt')

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

VariationEncoder :

class EfficientnetV2VarEncoder(pl.LightningModule):
    def __init__(self):
        super(EfficientnetV2VarEncoder, self).__init__()
        self.file_path = utils.ROOT_PATH + '/weights/EfficientNetv2VE'
        self.encoder = EncoderCoAtNet()
        self.latent_dim = 1024
        self.example_input_array = torch.rand(1, 3, 256, 256)
        self.example_output_array = torch.rand(1, 1024)
        self.save_hyperparameters()
        self.fc_mu = nn.Linear(1024, self.latent_dim)
        self.fc_var = nn.Linear(1024, self.latent_dim)

        try:
            torch.save(self, utils.ROOT_PATH + '/weights/' + 'VE.pt')
        except Exception as e:
            self.encoder = torch.load(utils.ROOT_PATH + '/weights/' + 'VE.pt').encoder
            self.fc_mu = torch.load(utils.ROOT_PATH + '/weights/' + 'VE.pt').fc_mu
            self.fc_var = torch.load(utils.ROOT_PATH + '/weights/' + 'VE.pt').fc_var

        try:
            self.fc_mu.load_state_dict(torch.load( utils.ROOT_PATH + '/weights/' + 'fc_mu.pth'))
        except Exception as e:
            torch.save(self.fc_mu.state_dict(),  utils.ROOT_PATH + '/weights/' + 'fc_mu.pth')

        try:
            self.fc_var.load_state_dict(torch.load(utils.ROOT_PATH + '/weights/' + 'fc_var.pth'))
        except Exception as e:
            torch.save(self.fc_var.state_dict(), utils.ROOT_PATH + '/weights/' + 'fc_var.pth')
        
    def forward(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + eps*std

VAE Model :-

pl.seed_everything(42)
    
class VariationalAutoEncoder(pl.LightningModule):
    def __init__(self, 
                    ) -> None:
        super(VariationalAutoEncoder, self).__init__()
        self.example_input_array = torch.zeros(1, 3, 256, 256)
        self.save_hyperparameters()
        self.encoder = EfficientnetV2VarEncoder()
        self.decoder = EfficientNetv2Decoder()
        self.encoder.train()
        self.decoder.train()
        self.latent_dim = 1024
        self.beta = 0
        
    def forward(self, x):
        try:
            mu, var = self.encoder(x)
            z = self.encoder.reparameterize(mu, var)
            x = self.decoder(z)
        except Exception as e:
            print(e, "Error!")
            x = torch.rand(4, 3, 256, 256)
            mu, var = self.encoder(x)
            z = self.encoder.reparameterize(mu, var)
            x = self.decoder(z)
        
        return x, mu, var

    def loss_function(self, recon_x, x, mu, logvar):
        MAE = nn.functional.l1_loss(recon_x, x, reduction='none')
        MAE = MAE.view(MAE.size(0), -1).mean(dim=1)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        
        if self.beta > 1:
            self.beta = 1
        self.log("loss/MAE_loss", MAE)
        self.log("loss/kld_loss", KLD)
        self.log("losss/beta", self.beta)
        self.log("Total", self.beta*KLD + MAE)
        loss = MAE + self.beta*KLD
        loss = loss.mean()
        return loss

    def training_step(self, batch, batch_idx):
        x, y = batch
        self.beta += 0.001
        x = x.view(x.size(1), x.size(2), x.size(3), x.size(4)).half()
        y = y.view(y.size(1), y.size(2), y.size(3), y.size(4)).half()
        x_hat, mu, log_var = self(x)
        print(x_hat.shape)
        loss = self.loss_function(x_hat, y, mu, log_var)
        self.log('train_loss', loss)
        if batch_idx % 1000 == 0:
            self.log_image(x, x_hat, y)
        return {"loss" : loss}

    def training_epoch_end(self, outputs)-> None:
        loss = outputs[0]['loss']
        try:
            avg_loss = torch.stack([x['loss'] for x in loss]).mean()
        except TypeError:
            avg_loss = loss
        self.log('train/loss_epoch', avg_loss)

  def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]    

DatasetLoader:

class AutoEncoderDataset(Dataset):
    def __init__(self, batch_size:int,
                    data_path, annotation) -> None:
        super(AutoEncoderDataset, self).__init__()
        self.data_path = data_path
        self.annotation = open(annotation, 
                                        'r').read().splitlines()
        self.batch_size = int(batch_size)

        self.preprocessing = ImagePreProcessing()

        self.index = 0

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index:int):
        
        i=0
        while True:
            i+=1
            if index+i >= len(self.annotation):
                index = 0
            video_path = self.annotation[index+i]
            video_path = os.path.join(self.data_path, video_path) 
            
            cap = cv2.VideoCapture(video_path.strip())     
            count: int = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            
            if not (cap.isOpened() and cap.get(cv2.CAP_PROP_FRAME_COUNT) > 0):
                continue
            
            if count < self.batch_size:
                ret_frames = np.random.randint(0, count, count)
            else:
                ret_frames= np.random.randint(0, count, self.batch_size)
            
                
            frames = []
            original = []
            # Get random frame indexes for batch size
            for frame in ret_frames:
                cap.set(1, frame)
                ret, frame = cap.read()
                if ret:
                    frame = np.transpose(frame, (2, 0, 1))
                    frame = self.preprocessing.transforms(torch.from_numpy(frame))
                    frame = self.preprocessing.preprocess(frame)
                    framex = self.preprocessing.augumentation(frame)
                    frames.append(framex)
                    framey = self.preprocessing.improve(frame)
                    original.append(framey)


            X = torch.stack(frames, dim=0)
            y = torch.stack(frames, dim=0)
            X = X
            y = X
            
            if(torch.isnan(X).any() or torch.isnan(y).any()):
                print("reported")
                continue
            else:
                break

        return X, y
    
class AutoEncoderDataModule(pl.LightningDataModule):
    def __init__(self, batch_size:int, num_workers:int,
                    data_path, annotation) -> None:
        super(AutoEncoderDataModule, self).__init__()
        self.annotation = annotation
        self.batch_size = int(batch_size)
        self.num_workers = int(num_workers)
        self.data_path = data_path

    def setup(self, stage=None):
        full_dataset = AutoEncoderDataset(self.batch_size,
                                           self.data_path, self.annotation)
        train_size = int(0.9 * len(full_dataset))
        val_size = int(0.025 * len(full_dataset))
        test_size = len(full_dataset) - train_size - val_size
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            full_dataset, [train_size, val_size, test_size])
        
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1, num_workers=self.num_workers,
                           shuffle=True, drop_last=True, pin_memory=True) 

    def val_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1, num_workers=self.num_workers,
                           shuffle=False, drop_last=True, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1, num_workers=self.num_workers,
                           shuffle=False, drop_last=True, pin_memory=True)

Pytorch Trainer config:

[AUTOENCODER_TRAIN]
max_epochs = 100
min_epochs = 50
#accelerator = gpu
benchmark = True
weights_summary = top
precision = 16
auto_lr_find = True
auto_scale_batch_size = True
auto_select_gpus = True
check_val_every_n_epoch = 1 
accumulate_grad_batches=64
fast_dev_run = False
enable_progress_bar = True
limit_val_batches=200
limit_train_batches=0.1
detect_anomaly=False

[AUTOENCODER_DATASET]
batch_size = 8
data_path = /mnt/nfs_share/nfs_share/Data/data/
num_workers = 12