Nearly identical predictions for all test images in multi-label image classification

I have a computer vision (CV) model built with transfer learning that appears to go through training and validation properly, but it ends up giving nearly the same predictions for every image in the test set. From searching various forums, this issue seems to be known, but in most cases the context is different from mine.

ML Task: Multi-label image classification where an image can belong to multiple classes
I have implemented a custom dataset class (used with a DataLoader) that returns an image and its labels, where the labels are multi-hot encoded over the classes, e.g. [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0.].
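For readers unfamiliar with the encoding, here is a minimal pure-Python sketch of how a multi-hot vector is built and decoded. The class names and the helpers `multi_hot` and `decode` are hypothetical, for illustration only; in the real dataset the vectors come from the annotation files.

```python
# Hypothetical example: multi-hot encoding for multi-label classification.
from typing import Iterable, List


def multi_hot(labels: Iterable[str], class_names: List[str]) -> List[float]:
    """Return a vector with 1.0 at the index of every class the image belongs to."""
    index = {name: i for i, name in enumerate(class_names)}
    vec = [0.0] * len(class_names)
    for name in labels:
        vec[index[name]] = 1.0
    return vec


def decode(vec: List[float], class_names: List[str]) -> List[str]:
    """Inverse of multi_hot: list the classes whose entry is 1."""
    return [name for name, v in zip(class_names, vec) if v == 1.0]


# Hypothetical class names; sorted so each index always means the same class.
class_names = sorted(["tree", "dog", "car", "people", "baby"])
vec = multi_hot(["dog", "people"], class_names)
print(vec)                       # [0.0, 0.0, 1.0, 1.0, 0.0]
print(decode(vec, class_names))  # ['dog', 'people']
```

An image with no classes would simply be an all-zero vector, which is why a multi-label model needs one independent sigmoid per class rather than a softmax.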

I wonder what I could be doing wrong. Here is my training code: I replaced the fc layer of the pretrained model and trained it. Is the model not learning, or is it overfitting? I have played around with multiple hyperparameters, but the outcome does not change. Any help would be appreciated.

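import torch
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_and_val_model(train_loader, val_loader, class_names):
    """
    Train a multi-label model.

    Uses an accuracy(outputs, labels) helper defined elsewhere (not shown).
    """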
    lr_rate = 1e-7
    num_epochs = 30

    # Initialize model
    model = models.resnet50(pretrained=True) # Resnet50

    # Freeze all layers except the final fully connected layer
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final fully connected layer
    # ResNet50
    num_classes = len(class_names)
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(in_features=2048, out_features=1024),
        torch.nn.ReLU(),
        torch.nn.Dropout(p=0.3),
        torch.nn.Linear(in_features=1024, out_features=num_classes),
        torch.nn.Sigmoid())

    # Define loss function and optimizer
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate)

    # Train model
    model.to(device)
    for epoch in range(num_epochs):
        train_loss = 0
        val_loss = 0
        train_acc = 0
        val_acc = 0
        model.train()
        for i, batch in enumerate(train_loader):
            images = batch['images'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(images)
            # outputs = outputs.logits  # Inception v3 only: in train mode it returns an InceptionOutputs tuple, ResNet50 returns a plain tensor
            # calculate loss
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()

            # update model parameters
            optimizer.step()

            # Calculate accuracy
            acc = accuracy(outputs, labels)
            train_acc += acc
            train_loss += loss.item()

        # Validation
        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                images = batch['images'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(images)
                acc = accuracy(outputs, labels)
                loss = criterion(outputs, labels)
                val_acc += acc
                val_loss += loss.item()

        # Display results
        print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}, Train_acc: {:.4f}, Val acc: {:.4f}'
              .format(epoch+1, num_epochs,
                      train_loss/len(train_loader), val_loss/len(val_loader),
                      train_acc/len(train_loader), val_acc/len(val_loader)))
    return model
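The `accuracy` helper used above is not shown in the question. As an assumption about what it might compute, here is a hypothetical pure-Python sketch of two common multi-label metrics: per-label accuracy and exact-match ratio. It expects raw logits (a model without a final Sigmoid layer) and uses a 0.5 threshold; the function names and the threshold are illustrative choices.

```python
# Hypothetical stand-ins for the accuracy() helper, on plain nested lists.
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(logits: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Threshold one image's logits into a 0/1 prediction per class."""
    return [1 if sigmoid(z) >= threshold else 0 for z in logits]


def label_accuracy(logits, targets, threshold: float = 0.5) -> float:
    """Fraction of individual (image, class) decisions that are correct."""
    correct = total = 0
    for row_logits, row_targets in zip(logits, targets):
        for p, t in zip(predict(row_logits, threshold), row_targets):
            correct += int(p == t)
            total += 1
    return correct / total


def exact_match_ratio(logits, targets, threshold: float = 0.5) -> float:
    """Fraction of images whose whole label vector is predicted correctly."""
    hits = [predict(row, threshold) == [int(t) for t in tgt]
            for row, tgt in zip(logits, targets)]
    return sum(hits) / len(hits)


logits = [[2.0, -1.0, 0.5], [-3.0, 1.5, -0.2]]  # raw model outputs, no Sigmoid
targets = [[1, 0, 0], [0, 1, 0]]
print(label_accuracy(logits, targets))     # 5 of 6 decisions correct
print(exact_match_ratio(logits, targets))  # 0.5
```

With tensors, `outputs.tolist()` produces these nested lists, or the same decision can be made directly with `(torch.sigmoid(outputs) >= 0.5).float()`. Per-label accuracy tends to look high on sparse multi-hot targets even for a model that predicts all zeros, so the exact-match ratio is a useful second check.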

My custom dataset classes:

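import os

import numpy as np
from PIL import Image
from torch.utils import data
from torchvision import transforms


class CustomDataset(data.Dataset):
    """
    A custom dataset class used to create the training and validation dataloaders.

    Image sizes should be changed depending on the model: 128x128 for ResNet and VGG, 299x299 for Inception.
    """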
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.filename_to_class = {}
        self.classname_to_filenames = {}
        self.classnames = set()
        self.transform = transforms.Compose([
            transforms.Resize((299, 299)), 
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # mean=[116.022, 106.491, 95.719],std=[75.824, 72.377, 74.867]
        ])

        # Create a dictionary mapping classnames to the list of image filenames
        for filename in os.listdir(os.path.join(self.root_dir, "annotations")):
            if filename.endswith(".txt"):
                class_name = os.path.splitext(filename)[0]
                with open(os.path.join(self.root_dir, "annotations", filename)) as f:
                    image_numbers = f.readlines()
                    image_filenames = ["im{}.jpg".format(n.strip()) for n in image_numbers]
                    self.classname_to_filenames[class_name] = image_filenames
                    self.classnames.add(class_name)

        # Fix the class order: iterating over a set of strings can give a different
        # order in each Python run, which would silently reorder the label indices
        self.classnames = sorted(self.classnames)

        # Create a dictionary with multi labels
        for class_name in self.classnames: #os.listdir(os.path.join(self.root_dir, "annotations")):
            with open(os.path.join(self.root_dir, "annotations", class_name + '.txt'), "r") as f:
                images = f.readlines()
                images = [int(x.strip()) for x in images]
            for image in images:
                # float32 so the targets match the dtype of the model outputs
                labels = np.zeros(len(self.classnames), dtype=np.float32)
                image_filename = "im{}.jpg".format(image)
                # check if image is in a class and store the label of that class
                for i, other_class in enumerate(self.classnames):
                    if image_filename in self.classname_to_filenames[other_class]:
                        labels[i] = 1
                self.filename_to_class[image_filename] = labels
        print(f'show a sample labels {self.filename_to_class[list(self.filename_to_class.keys())[0]]}')

    def __len__(self):
        """The size of the dataset"""
        return len(self.filename_to_class)

    def __getitem__(self, index):
        """Get a specific image and label

        Read the corresponding image and convert it into a PyTorch
        Tensor

        """
        filename = list(self.filename_to_class.keys())[index]
        image_path = os.path.join(self.root_dir, "images", filename)
        image = Image.open(image_path).convert('RGB')
        label = self.filename_to_class[filename]

        # Apply transformations
        image = self.transform(image)

        # Create a dictionary containing the image and the label
        return {'images': image, 'labels': label}
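The label-building loop in `__init__` checks every image against every class list. A hypothetical single-pass alternative inverts the class-to-filenames mapping directly and sorts the class names so that label index i refers to the same class in every run (iterating over a `set` of strings is not guaranteed to give the same order from one Python process to the next). The function name is illustrative. Inside a `Dataset` you would still convert each list to a float32 NumPy array or tensor before returning it, because PyTorch's default collate function treats a plain Python list as a sequence of separate fields rather than as one vector.

```python
# Hypothetical single-pass label builder; class/file names are illustrative.
from typing import Dict, List, Tuple


def build_multi_hot_labels(
    classname_to_filenames: Dict[str, List[str]],
) -> Tuple[List[str], Dict[str, List[float]]]:
    """Invert {class: [files]} into {file: multi-hot vector} in one pass."""
    # Sorting fixes the meaning of each label index across runs.
    class_names = sorted(classname_to_filenames)
    filename_to_labels: Dict[str, List[float]] = {}
    for i, name in enumerate(class_names):
        for filename in classname_to_filenames[name]:
            # All-zero vector the first time a file is seen, then set this class's slot.
            labels = filename_to_labels.setdefault(filename, [0.0] * len(class_names))
            labels[i] = 1.0
    return class_names, filename_to_labels


mapping = {"tree": ["im1.jpg", "im3.jpg"], "dog": ["im1.jpg"], "car": ["im2.jpg"]}
names, labels = build_multi_hot_labels(mapping)
print(names)              # ['car', 'dog', 'tree']
print(labels["im1.jpg"])  # [0.0, 1.0, 1.0]
```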

class TestDataset(data.Dataset):
    """
    Create a dataset for testing, where we don't have the ground truth
    """
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.transform = transforms.Compose([
            transforms.Resize((299, 299)), # 128, 128
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) #mean=[116.022, 106.491, 95.719],std=[75.824, 72.377, 74.867]
        self.images = os.listdir(os.path.join(root_dir, "images"))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Load image
        filename = self.images[index]
        image_path = os.path.join(self.root_dir, "images", filename)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return {'images':image, 'filenames':filename}
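To turn test-set outputs into predicted classes, the logits (from a model without a final Sigmoid layer) are passed through a sigmoid and thresholded per class; in PyTorch that is `torch.sigmoid(outputs) >= 0.5`. The pure-Python sketch below shows the same decision for individual images. The function name, the class names, the example logits and the 0.5 threshold are all assumptions; per-class thresholds tuned on the validation set are a common alternative.

```python
# Hypothetical test-time decoding for one image; names and threshold are illustrative.
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_classes(logits: Sequence[float], class_names: Sequence[str],
                      threshold: float = 0.5) -> List[str]:
    """Return every class whose predicted probability reaches the threshold."""
    return [name for name, z in zip(class_names, logits) if sigmoid(z) >= threshold]


class_names = ["car", "dog", "tree"]
outputs = {"im10.jpg": [-2.0, 1.2, 0.3], "im11.jpg": [0.8, -0.5, -4.0]}
for filename, logits in outputs.items():
    print(filename, logits_to_classes(logits, class_names))
# im10.jpg ['dog', 'tree']
# im11.jpg ['car']
```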

Your implementation looks quite good, but there might be a few issues that could lead to similar predictions for all the test set data. Here are a few things you could try to improve your model’s performance:

  1. Use BCEWithLogitsLoss correctly: You apply a Sigmoid activation in your final layer and also use BCEWithLogitsLoss as your criterion. BCEWithLogitsLoss already combines the sigmoid activation and the binary cross-entropy loss in a numerically stable way, so your outputs go through the sigmoid twice. The loss then only ever sees probabilities between sigmoid(0) = 0.5 and sigmoid(1) ≈ 0.73, so the model cannot express a confident prediction for any class. Remove the Sigmoid activation from the final layer of your model, like this:
model.fc = torch.nn.Sequential(
    torch.nn.Linear(in_features=2048, out_features=1024),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.3),
    torch.nn.Linear(in_features=1024, out_features=num_classes))
  2. Unfreeze some layers: You are freezing all layers except the final fully connected layer. In some cases it is better to fine-tune some of the earlier layers as well. Try unfreezing the last few layers before the final layer so the model can adapt to your specific dataset.
# Unfreeze the last two residual stages (layer3 and layer4 in torchvision's ResNet50)
for name, param in model.named_parameters():
    if "layer4" in name or "layer3" in name:
        param.requires_grad = True
  3. Learning rate: You have set a very low learning rate of 1e-7. Depending on the pre-trained weights and your dataset, it is probably too small to make meaningful updates. Try a higher learning rate, such as 1e-4 or 1e-5, and check whether it improves the model's performance.
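A small pure-Python sketch makes the BCEWithLogitsLoss point concrete. It assumes the question's setup (Sigmoid as the last layer, followed by BCEWithLogitsLoss) and only uses the `math` module; `sigmoid` and `bce` are hypothetical helpers written out by hand, not PyTorch functions.

```python
# Hypothetical illustration of the double sigmoid (pure Python, no PyTorch needed).
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce(prob: float, target: float) -> float:
    """Binary cross-entropy for one label, given a probability."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))


# With nn.Sigmoid() as the last layer, the model output p already lies in (0, 1).
# BCEWithLogitsLoss treats p as a logit and applies sigmoid again, so the
# probability it scores is sigmoid(p), which stays between 0.5 and about 0.731.
for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"model output {p:.2f} -> probability seen by the loss {sigmoid(p):.3f}")

# As p approaches 0 for a negative label, the loss only approaches log 2.
print(f"lowest loss for a 0 target: {bce(sigmoid(0.0), 0.0):.3f}")  # 0.693
# As p approaches 1 for a positive label, the loss only approaches -log(sigmoid(1)).
print(f"lowest loss for a 1 target: {bce(sigmoid(1.0), 1.0):.3f}")  # 0.313
```

Since most entries of a multi-hot target are 0 and their loss can never drop below log 2, a likely outcome is that the network pushes every output toward the same saturated value regardless of the image, which matches the symptom in the question.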
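For the learning-rate point, a rough back-of-the-envelope bound helps. Adam's per-step update for a weight is lr · m̂ / (√v̂ + ε), which in typical cases has a magnitude of about lr, so over S steps a weight moves by at most roughly S · lr. The helper name and the 200 steps per epoch below are hypothetical; substitute `len(train_loader)` for the real value.

```python
# Back-of-the-envelope bound on how far Adam can move one weight.
def max_total_drift(lr: float, epochs: int, steps_per_epoch: int) -> float:
    """Approximate upper bound: Adam's step is typically about lr in magnitude,
    so epochs * steps_per_epoch steps move a weight by at most roughly that times lr."""
    return lr * epochs * steps_per_epoch


steps_per_epoch = 200  # hypothetical; use len(train_loader) for the real value
for lr in (1e-7, 1e-5, 1e-4):
    drift = max_total_drift(lr, epochs=30, steps_per_epoch=steps_per_epoch)
    print(f"lr={lr:g}: each weight moves by at most ~{drift:.2g} over 30 epochs")
```

Under these assumptions, 1e-7 allows a total drift of only about 0.0006 per weight over the whole run, versus about 0.6 at 1e-4. PyTorch's default initialisation for `Linear(2048, 1024)` draws weights uniformly from about ±0.022 (±1/√2048), so at 1e-7 the new head barely moves away from its random starting point.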

Thank you for pointing out these issues.

  1. Removing that extra sigmoid solved the problem of similar predictions across all the test set data. It seems that was the main problem.

  2. Unfreezing a few layers seems to have fixed what I took for overfitting, which showed up as a validation loss lower than the training loss.

  3. I will further test the effect of these hyperparameters.