Data augmentation reduces model credibility scores

I use the torchvision Faster R-CNN model:

    def create_model(pretrained=False, num_classes=2):
        model = tv.models.detection.fasterrcnn_resnet50_fpn(
            num_classes=num_classes,
            pretrained_backbone=pretrained,
            box_detections_per_img=500
        )
        return model
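
The snippets assume roughly these imports (tf being torchvision.transforms.functional):

    import random
    from typing import Tuple

    import torch
    import torchvision as tv
    import torchvision.transforms.functional as tf
    from torch.utils.data import DataLoader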

It is trained with the following method:

    def train(self, transform=False,
              learning_rate=None, pretrained=False, optimizer="adam",
              credibility_score=0.6):
        """
            Train over dataset
        """
        # Load data
        dataset, testset = self.get_datasets(transform)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.BATCH_SIZE,
            shuffle=True,
            pin_memory="cuda" in self.device.type,
            num_workers=self.NUM_WORKERS,
            collate_fn=collate
        )
        # Create model
        model = self.create_model(
            pretrained=pretrained
        )
        model = model.to(self.device)
        model.train()

        # Pass parameters that need optimizing to the optimizer
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = self.create_optimizer(
            params, learning_rate=learning_rate, optim_type=optimizer)

        # Main train loop
        for epoch in range(self.epochs):
            # Reset loss
            epoch_loss = .0

            # When augmenting, run over the data REDO_WITH_AUGMENT times per epoch
            for _ in range(self.REDO_WITH_AUGMENT if transform else 1):

                # Iterate over the dataloader
                for index, (img, targets) in enumerate(dataloader):
                    img = [item.to(self.device) for item in img]
                    for target in targets:
                        for key in target:
                            target[key] = target[key].to(self.device)

                    optimizer.zero_grad()

                    # Forward pass: in train mode the model returns a dict
                    # with four losses: loss_classifier, loss_box_reg,
                    # loss_objectness and loss_rpn_box_reg
                    output = model(img, list(targets))
                    # Sum the losses for backpropagation
                    losses = sum(output.values())

                    # Backpropagate
                    losses.backward()
                    optimizer.step()

                    epoch_loss += float(losses.item())

            model.eval()
            
            rank, recall, precision, f1 = evaluate_testset(
               model, testset, self.device, credibility_score=credibility_score)
            model.train()

        # Save the resulting model to path
        torch.save(model.state_dict(), self.MODEL_SAVE_PATH)
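
(For completeness: collate keeps the images and targets as lists instead of stacking them, since the detection model expects lists and the images may differ in size. A minimal sketch of such a collate, as the original helper isn't shown here:)

    def collate(batch):
        # Keep the per-sample structure: the detection model takes a
        # list of images and a list of target dicts, not stacked tensors.
        return tuple(zip(*batch))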

The dataset's __getitem__ method:

    def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, dict]:
        """
            Returns:
                source (tensor): Tensor of size self.__size
                target (dict): Dictionary with keys
                - boxes (FloatTensor[N, 4]): The ground-truth boxes
                 [x1, y1, x2, y2]
                - labels (Int64Tensor[N]): The class label of each box
        """
        image, targets = self.data[index]
        if self.transform:
            _, ymax, xmax = image.shape  # image is (C, H, W)
            device = targets["boxes"].device
            # Vertical flip: vflip mirrors top-bottom, so y-coordinates change
            if random.random() > .5:
                image = tf.vflip(image)
                # Change boxes coordinates
                targets["boxes"] = torch.abs(
                    torch.tensor(
                        [0, ymax, 0, ymax],
                        device=device
                    ) - targets["boxes"]
                )
                # Make y1 < y2
                targets["boxes"][:, [1, 3]] = targets["boxes"][:, [3, 1]]

            # Horizontal flip: hflip mirrors left-right, so x-coordinates change
            if random.random() > .5:
                image = tf.hflip(image)
                targets["boxes"] = torch.abs(
                    torch.tensor([xmax, 0, xmax, 0], device=device) - targets["boxes"]
                )
                # Change the columns to make x1 < x2
                targets["boxes"][:, [0, 2]] = targets["boxes"][:, [2, 0]]

            # Color jitter
            # NOTE: ColorJitter() with default arguments is a no-op;
            # pass ranges (e.g. brightness=.3) for it to actually jitter
            if random.random() > .5:
                image = tv.transforms.ColorJitter()(image)

        return image, targets
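
As a sanity check of the box arithmetic for the top-bottom flip (toy values): the flip maps y to ymax - y, so [x1, y1, x2, y2] should become [x1, ymax - y2, x2, ymax - y1], which is exactly what the abs-and-swap above produces:

    import torch

    boxes = torch.tensor([[10., 20., 50., 60.]])  # [x1, y1, x2, y2]
    ymax = 100
    flipped = torch.abs(torch.tensor([0, ymax, 0, ymax]) - boxes)
    print(flipped)  # tensor([[10., 80., 50., 40.]]) -- mirrored, but y1 > y2
    flipped[:, [1, 3]] = flipped[:, [3, 1]]
    print(flipped)  # tensor([[10., 40., 50., 80.]]) == [x1, ymax-y2, x2, ymax-y1]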

What ends up happening is that the models trained on the non-augmented dataset outperform the ones trained with the transform. (When the transform parameter is set, each sample independently gets a 50% chance of a top-bottom flip, a 50% chance of a left-right flip, and a 50% chance of color jitter; the two flips together amount to a 180-degree rotation.)

The scores from the models trained on augmented data are far worse than those trained without the transform. I have triple-checked the transform and it seems fine (the boxes and the images look correct). Since the test function is indifferent to this method, it seems logical that the way I do augmentation is wrong for PyTorch. Is there something I am doing wrong here?

If you are using data augmentation, the number of epochs needed to reach the same performance might increase, but it should also allow you to train for longer and thus reduce the loss further.
Depending on the images you are using, the augmentation might also be too aggressive.
E.g. I don’t think a vertical flip works well out of the box on “natural” images; the model would need to learn these new samples, which most likely won’t occur in the dataset.

Hi!

These images could be described as crops of geographic maps. They really have no “gravity”, no up/down or left/right in that sense, so I would have assumed flipping to be an ideal augmentation for them. We moved on, as the results from the unaugmented dataset were satisfactory. I also tried running 50 epochs, but the results were still worse.

I had a discussion with my professor, and it seems the “credibility score” is a somewhat artificial limit for a classification task; if it becomes necessary to improve the results, I should implement an optimization over the test set to find the score threshold that performs best there.
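
A sketch of what that search could look like (using evaluate_testset as in the train method above; the candidate grid and the F1 criterion are arbitrary choices on my part):

    def best_credibility_score(model, testset, device, candidates=None):
        """Sweep candidate thresholds and return the best one by F1."""
        candidates = candidates or [i / 20 for i in range(2, 19)]
        best_score, best_f1 = None, -1.0
        for score in candidates:
            _, recall, precision, f1 = evaluate_testset(
                model, testset, device, credibility_score=score)
            if f1 > best_f1:
                best_score, best_f1 = score, f1
        return best_score, best_f1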

Thank you for your time, I truly appreciate your efforts on the forum.

I have found the error in the code above. In the dataset class, self.data is defined as a List[Tuple[Tensor, Dict[str, Tensor]]]. The changes made to targets["boxes"] modified the underlying data in the self.data list, which meant that any epoch after the first could contain “double” changes: the boxes were transformed again on top of the previous epoch's transform and ended up incorrect.
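
A minimal illustration of the aliasing (toy values):

    import torch

    stored = {"boxes": torch.tensor([[10., 20., 50., 60.]])}
    data = [("image", stored)]  # stand-in for self.data

    _, targets = data[0]        # no copy: targets IS the stored dict
    targets["boxes"] = targets["boxes"] + 1.0
    print(stored["boxes"])      # tensor([[11., 21., 51., 61.]])
    # The "augmentation" leaked into the dataset and compounds every epoch.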

I changed the code to

    def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, dict]:
        """
            Returns:
                source (tensor): Tensor of size self.__size
                target (dict): Dictionary with keys
                - boxes (FloatTensor[N, 4]): The ground-truth boxes
                 [x1, y1, x2, y2]
                - labels (Int64Tensor[N]): The class label of each box
        """
        image, targets = self.data[index]
        # Force copies so the augmentations below cannot mutate self.data.
        # dict(targets) is a shallow copy; it is enough here because "boxes"
        # is re-assigned (not modified in place) before the column swap.
        image = image.clone().detach()
        targets = dict(targets)
        if self.transform:
            _, ymax, xmax = image.shape  # image is (C, H, W)
            device = targets["boxes"].device
            # Vertical flip: vflip mirrors top-bottom, so y-coordinates change
            if random.random() > .5:
                image = tf.vflip(image)
                # Change boxes coordinates
                targets["boxes"] = torch.abs(
                    torch.tensor(
                        [0, ymax, 0, ymax],
                        device=device
                    ) - targets["boxes"]
                )
                # Make y1 < y2
                targets["boxes"][:, [1, 3]] = targets["boxes"][:, [3, 1]]

            # Horizontal flip: hflip mirrors left-right, so x-coordinates change
            if random.random() > .5:
                image = tf.hflip(image)
                targets["boxes"] = torch.abs(
                    torch.tensor([xmax, 0, xmax, 0], device=device) - targets["boxes"]
                )
                # Change the columns to make x1 < x2
                targets["boxes"][:, [0, 2]] = targets["boxes"][:, [2, 0]]

            # Color jitter
            # NOTE: ColorJitter() with default arguments is a no-op;
            # pass ranges (e.g. brightness=.3) for it to actually jitter
            if random.random() > .5:
                image = tv.transforms.ColorJitter()(image)

        return image, targets

This seems to have fixed the issue. I am still running the training; I only came back to this for completely unrelated reasons.
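
One caveat for anyone reusing this: dict(targets) is only a shallow copy. It is sufficient here because targets["boxes"] is re-assigned to a fresh tensor before the in-place column swap, but a more defensive variant would clone every tensor up front:

    # Defensive copies: no later in-place op can leak into self.data
    image, targets = self.data[index]
    image = image.clone()
    targets = {key: value.clone() for key, value in targets.items()}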