Loss does not decrease for binary classification

I am trying to implement binary classification. I have a dataset of 100K pre-resized images (3 channels, 224 x 224 px) and am training a model to decide whether a picture is safe for work or not. I am a data engineer with a statistics background, and I have been working on this model for the last 5–10 days. I have read many answers from ptrblck and tried to implement the suggested solutions, but unfortunately the loss didn't decrease.

Here is the class, implemented with PyTorch Lightning:

from .dataset import CloudDataset
from .split import DatasetSplit
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics import Accuracy
from torch import stack
from torch.nn import BCEWithLogitsLoss, Conv2d, Dropout, Linear, MaxPool2d, ReLU
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import ToTensor
from util import logger
from util.config import config


class ClassifyModel(LightningModule):
    """Binary (safe-for-work or not) image classifier built on Lightning.

    ``forward`` returns one raw logit per image. The loss is
    ``BCEWithLogitsLoss`` (which applies the sigmoid internally); a sigmoid is
    applied explicitly only where probabilities are needed, i.e. for the
    accuracy metric.
    """

    def __init__(self):
        super().__init__()

        # custom dataset split class
        ds = DatasetSplit(config.s3.bucket, config.train.ratio)

        # split records for train, validation and test
        self._train_itr, self._valid_itr, self._test_itr = ds.split()

        self.conv1 = Conv2d(3, 32, 3, padding=1)
        self.conv2 = Conv2d(32, 64, 3, padding=1)
        # NOTE(review): conv3 is applied twice in ``forward`` (shared weights).
        # A separate conv4 would add capacity, but it would also change the
        # state_dict keys, so existing checkpoints would no longer load.
        self.conv3 = Conv2d(64, 64, 3, padding=1)

        self.pool = MaxPool2d(2, 2)

        # Four 2x2 max-pools reduce a 224x224 input to 14x14 with 64 channels
        # (64 * 14 * 14 == 12544, the same value as the former 7 * 28 * 64).
        self.fc1 = Linear(64 * 14 * 14, 512)
        self.fc2 = Linear(512, 16)
        # NOTE(review): a 4-unit ReLU bottleneck can "die" (all units output
        # 0), which leaves the final logit constant; consider a wider layer or
        # batch normalization if training stalls.
        self.fc3 = Linear(16, 4)
        self.fc4 = Linear(4, 1)

        self.dropout = Dropout(0.25)

        self.relu = ReLU(inplace=True)

        self.accuracy = Accuracy()

        # Parameter-free loss module, built once instead of on every step.
        self.criterion = BCEWithLogitsLoss()

    def forward(self, x):
        """Return one raw logit per image.

        Args:
            x: image batch; assumed to be ``[N, 3, 224, 224]`` because the
                ``fc1`` input size is fixed for that resolution.

        Returns:
            Tensor of shape ``[N]`` holding unnormalised logits, suitable for
            ``BCEWithLogitsLoss``.
        """
        # [N, 3, 224, 224] -> [N, 32, 112, 112]
        x = self.pool(self.relu(self.conv1(x)))
        # -> [N, 64, 56, 56]
        x = self.pool(self.relu(self.conv2(x)))
        # -> [N, 64, 28, 28]
        x = self.pool(self.relu(self.conv3(x)))
        # -> [N, 64, 14, 14] (conv3 deliberately reused, see __init__)
        x = self.pool(self.relu(self.conv3(x)))
        x = self.dropout(x)

        # Flatten per sample. Unlike view(-1, 12544), a wrong input resolution
        # now fails loudly in fc1 instead of silently mixing samples across
        # the batch dimension.
        x = x.flatten(1)
        # -> [N, 12544]
        x = self.relu(self.fc1(x))
        # -> [N, 512]
        x = self.relu(self.fc2(x))
        # -> [N, 16]
        x = self.relu(self.fc3(x))
        # -> [N, 4]
        # No dropout on the output: dropping the logit would randomly force
        # predictions to p=0.5 (logit 0) or rescale them, corrupting the loss.
        x = self.fc4(x)

        # [N, 1] -> [N] so the output matches the target shape of the loss.
        return x.squeeze(1)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        """Compute the BCE loss for one batch and update training accuracy."""
        image, target = batch

        # A single forward pass serves both the loss and the metric, so the
        # accuracy scores exactly the predictions (same dropout mask) trained on.
        logits = self(image)
        loss = self.criterion(logits, target.float())

        # Accuracy thresholds its inputs at 0.5 as probabilities, so convert
        # logits with a sigmoid; detach so the metric stays out of the graph.
        self.accuracy(logits.detach().sigmoid(), target.int())

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Return the BCE loss of one validation batch."""
        image, target = batch

        logits = self(image)
        loss = self.criterion(logits, target.float())

        return {'val_loss': loss}

    def collate_fn(self, batch):
        """Drop ``None`` samples, then collate the rest with the default collator."""
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)

    def _make_loader(self, records, shuffle):
        """Build a batch-size-32 DataLoader over ``records`` of ``config.s3.bucket``.

        Args:
            records: one of the split iterators created in ``__init__``.
            shuffle: whether the loader reshuffles samples every epoch.
        """
        # Load in the main process (0 workers) when config.train.test is set.
        workers = 0 if config.train.test else config.train.workers
        # NOTE(review): ToTensor does not standardize inputs (at most it scales
        # uint8 pixels to [0, 1]); a Normalize(mean, std) step often helps.
        dataset = CloudDataset(config.s3.bucket, records, ToTensor())
        return DataLoader(
            dataset=dataset,
            batch_size=32,
            shuffle=shuffle,
            num_workers=workers,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self._make_loader(self._train_itr, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self._valid_itr, shuffle=False)

    def test_dataloader(self):
        # Evaluation metrics do not depend on sample order, so (like the
        # validation loader) the test loader is not shuffled.
        return self._make_loader(self._test_itr, shuffle=False)

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss over all validation batches."""
        if not outputs:
            # stack() of an empty list raises; nothing to report.
            return
        avg_loss = stack([x['val_loss'] for x in outputs]).mean()

        logger.info(f'Validation loss is {avg_loss}')

    def training_epoch_end(self, outs):
        """Log this epoch's training accuracy and reset the metric state."""
        accuracy = self.accuracy.compute()

        logger.info(f'Training accuracy is {accuracy}')

        # The metric accumulates state across update() calls; without a reset
        # each log line would report the running accuracy over all epochs.
        self.accuracy.reset()

Here is the custom log output:

epoch 0
Validation loss is 0.5988735556602478
Training accuracy is 0.4441356360912323

epoch 1
Validation loss is 0.6406065225601196
Training accuracy is 0.4441356360912323

epoch 2
Validation loss is 0.621654748916626
Training accuracy is 0.443579763174057

epoch 3
Validation loss is 0.5089989304542542
Training accuracy is 0.4580322504043579

epoch 4
Validation loss is 0.5484663248062134
Training accuracy is 0.4886047840118408

epoch 5
Validation loss is 0.5552918314933777
Training accuracy is 0.6142301559448242

epoch 6
Validation loss is 0.661466121673584
Training accuracy is 0.625903308391571

The last squeeze() operation (x = x.squeeze(1)) is most likely not needed, but it is fine as long as your target has the same shape ([32]).

Could you try to overfit a small data sample, e.g. just 10 samples, and see if your model is able to do so by playing around with some hyperparameters?
I think newer versions of Lightning also ship with functionality that does exactly this (the `overfit_batches` argument of the `Trainer`).

When I remove x = x.squeeze(1) from forward, the loss function throws the following exception:

ValueError: Target size (torch.Size([32])) must be the same as input size (torch.Size([32, 1]))

I added batch normalization and weight initialization and updated the layers, and that solved the problem. Thank you, @ptrblck!