Accuracy (and other metrics) in multi-label edge segmentation

Hi,
I am relatively new to PyTorch and am currently working on edge segmentation with CASENet. The key idea of CASENet is that a pixel can belong to more than one class at the same time (multi-label). I am using the PyTorch implementation of CASENet provided by DFF on my custom dataset, which has 3 classes.

I am trying to calculate per-class accuracy in this setting, because the loss alone is not enough to monitor the model’s performance.

Here’s what I have tried so far:

    def training(self, epoch):
        self.model.train()
        tbar = tqdm(self.trainloader)
        # tbar = self.trainloader
        train_loss = 0.
        train_loss_all = 0.

        acc_class0 = 0.
        acc_class1 = 0.
        acc_class2 = 0.
        acc_class0_all = 0.
        acc_class1_all = 0.
        acc_class2_all = 0.
        
        correct_class0 = 0.
        correct_class1 = 0.
        correct_class2 = 0.
        correct_class0_all = 0.
        correct_class1_all = 0.
        correct_class2_all = 0.
        
        total = 0.
        total_all = 0.
        
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)

            image = image.cuda()
            target = target.cuda()

            outputs = self.model(image.float())

            loss = self.criterion(outputs, target)
            loss.backward()

            self.optimizer.step()

            train_loss += loss.item()
            train_loss_all += loss.item()

            total += target[0][1,:,:].nelement() * target.size(0)
            total_all += target[0][1,:,:].nelement() * target.size(0)

            correct_0, correct_1, correct_2 = self.get_accuracy_stats(outputs[1], target)
            correct_class0 += correct_0
            correct_class1 += correct_1
            correct_class2 += correct_2
            
            correct_class0_all += correct_0
            correct_class1_all += correct_1
            correct_class2_all += correct_2

            if i == 0 or (i+1) % 20 == 0:
                train_loss = train_loss / min(20, i + 1)
                acc_class0 = 100 * correct_class0 / total
                acc_class1 = 100 * correct_class1 / total
                acc_class2 = 100 * correct_class2 / total

                # self.logger.info('Epoch [%d], Batch [%d],\t train-loss: %.4f' % (
                #        epoch + 1, i + 1, train_loss))
                self.logger.info('Epoch [%d], Batch [%d],\t train-loss: %.4f,\t class1 acc: %.3f,\t class2 acc: %.3f,\t class3 acc: %.3f' % (
                    epoch + 1, i + 1, train_loss, acc_class0, acc_class1, acc_class2))
                train_loss = 0.
                total = 0.
                correct_class0 = 0.
                correct_class1 = 0.
                correct_class2 = 0.

        acc_class0_all = 100 * correct_class0_all / total_all
        acc_class1_all = 100 * correct_class1_all / total_all
        acc_class2_all = 100 * correct_class2_all / total_all

        #self.logger.info('-> Epoch [%d], Train epoch loss: %.3f' % (
        #                epoch + 1, train_loss_all / (i + 1)))
        self.logger.info('-> Epoch [%d], Train epoch loss: %.3f,\t class1 epoch acc: %.3f,\t class2 epoch acc: %.3f,\t class3 epoch acc: %.3f' % (
                         epoch + 1, train_loss_all / (i + 1), acc_class0_all, acc_class1_all, acc_class2_all))

        if not self.args.no_val:
            # save a checkpoint every 20 epochs and after the final epoch
            filename = "checkpoint_%s.pth.tar" % (epoch + 1)
            if (epoch + 1) % 20 == 0 or epoch == self.args.epochs - 1:
                utils.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    }, self.args, filename)

        self.viz.line(win=self.loss_plot, name='train_loss', update='append', X=np.array([epoch+1]),
                 Y=np.array([train_loss_all / (i + 1)]))

    def get_accuracy_stats(self, predicted, targets):
        correct_class_outputs = [[] for x in range(predicted[0].size(0))]
    
        for i in range(targets.size(0)):
            predictions = predicted[i]
            target = targets[i]
            pad_mask = target[0, :, :]
            target = target[1:, :, :]
            target_nopad = torch.mul(target, pad_mask)
            predictions_nopad = torch.mul(predictions, pad_mask)
    
            for j in range(predictions_nopad.size(0)):
                class_pred = predictions_nopad[j]
                class_target = target_nopad[j]
    
                class_output = torch.sigmoid(class_pred)
                class_output = torch.round(class_output)
                correct_class_outputs[j].append((class_output == class_target).sum().item())
    
        return sum(correct_class_outputs[0]), sum(correct_class_outputs[1]), sum(correct_class_outputs[2])
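
For comparison, here is a minimal vectorized sketch of the same per-class count that also drops padded pixels from both the numerator and the denominator. It assumes target[:, 0] is 1 on valid pixels and 0 on padding (which is what multiplying by pad_mask implies), and `masked_accuracy_stats` is a hypothetical helper, not part of DFF:

```python
import torch


@torch.no_grad()
def masked_accuracy_stats(logits, targets, threshold=0.5):
    """Per-class correct-pixel counts and the number of valid pixels.

    Assumed layout (hypothetical helper, not part of DFF):
      logits:  [N, C, H, W] raw outputs, e.g. the fused output outputs[1]
      targets: [N, C + 1, H, W]; targets[:, 0] is 1 on valid pixels and 0 on
               padding, targets[:, 1:] are the binary per-class edge labels
    """
    valid = targets[:, :1].bool()               # [N, 1, H, W], broadcasts over classes
    labels = targets[:, 1:].bool()              # [N, C, H, W]
    preds = torch.sigmoid(logits) > threshold   # same decision as round(sigmoid) at 0.5

    # Count matches only where the pixel is valid; sum over batch and space.
    correct = ((preds == labels) & valid).sum(dim=(0, 2, 3))
    num_valid = int(valid.sum().item())
    return correct.tolist(), num_valid
```

In `training`, a call such as `correct, n_valid = masked_accuracy_stats(outputs[1], target)` could replace both the `get_accuracy_stats` call and the `nelement()`-based update, adding `n_valid` to `total` and `total_all`.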

Some background:

In the line:

    outputs = self.model(image.float())

`outputs` is a tuple of two logit tensors with shapes ([6, 3, 512, 512], [6, 3, 512, 512]). The first item is the side5 output of the model and the second is the output of the fused layer. If I understand this correctly, the 3 channels are the edge maps of the respective classes. I only use the second item (the fused output) to calculate the accuracy, since I am not interested in the side5 output.

My target has shape [6, 4, 512, 512], where target[i][0,:,:] is a padding mask that is excluded from the loss calculation and target[i][1:,:,:] is the actual per-class target.
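
One detail worth checking with this layout, assuming the mask is 1 on valid pixels and 0 on padding: in `get_accuracy_stats`, multiplying the logits by pad_mask sets padded pixels to logit 0, sigmoid(0) = 0.5, and torch.round rounds half to even, so the prediction becomes 0. The masked target is also 0, so every padded pixel is counted as correct, and `total` counts it as well. The two hypothetical helpers below reproduce this on a made-up 4-pixel example for one class:

```python
import torch


def posted_style_accuracy(logits, labels, pad_mask):
    """Mimics get_accuracy_stats for one class channel: mask, sigmoid, round,
    then count matches over all pixels, padded ones included."""
    pred = torch.round(torch.sigmoid(logits * pad_mask))
    target = labels * pad_mask
    return (pred == target).sum().item() / labels.numel()


def masked_accuracy(logits, labels, pad_mask):
    """Same decision rule, but only valid pixels (pad_mask == 1) are counted."""
    valid = pad_mask.bool()
    pred = torch.sigmoid(logits) > 0.5
    correct = ((pred == labels.bool()) & valid).sum().item()
    return correct / valid.sum().item()


# Hypothetical example: pixels 0-1 are real, pixels 2-3 are padding.
pad_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
labels = torch.tensor([1.0, 0.0, 0.0, 0.0])
logits = torch.tensor([-5.0, -5.0, 3.0, 3.0])  # the edge at pixel 0 is missed

print(posted_style_accuracy(logits, labels, pad_mask))  # 0.75
print(masked_accuracy(logits, labels, pad_mask))        # 0.5
```

How much this matters depends on how much padding the batches contain.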

Is my per-class accuracy calculation correct, and am I on the right path? Or is there an easier way to do this? I feel that I have missed something, because the reported accuracy is above 90% for all classes in almost every batch.
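
A quick sanity check for high pixel accuracy on edge maps: edge pixels are usually a small fraction of each image, so a model that predicts "no edge" everywhere already scores a high accuracy. The hypothetical helper below computes that trivial baseline per class over valid pixels (same assumed layout: channel 0 is 1 on valid pixels). If the reported accuracies are close to it, accuracy is not saying much about edge quality:

```python
import torch


@torch.no_grad()
def all_negative_baseline(targets):
    """Per-class pixel accuracy (in %) of a model that predicts "no edge"
    everywhere, counted over valid (non-padded) pixels only.

    Assumes targets has shape [N, C + 1, H, W] with targets[:, 0] == 1 on
    valid pixels and the binary per-class edge labels in targets[:, 1:].
    """
    valid = targets[:, :1].bool()
    labels = targets[:, 1:].bool()
    negatives = (~labels & valid).sum(dim=(0, 2, 3)).float()  # background pixels per class
    return 100.0 * negatives / valid.sum().float()
```

Calling it on one training batch, e.g. `all_negative_baseline(target)`, gives one number per class to compare against the logged accuracies.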

Here is a sample output:

    2022-02-23 16:52:39:
    =>Epoches 0, learning rate = 0.0100
    2022-02-23 16:52:42: Epoch [1], Batch [1], train-loss: 0.8800, class1 acc: 3.627, class2 acc: 99.756, class3 acc: 3.417
    2022-02-23 16:52:55: Epoch [1], Batch [20], train-loss: 0.6791, class1 acc: 80.341, class2 acc: 99.837, class3 acc: 60.657
    2022-02-23 16:53:10: Epoch [1], Batch [40], train-loss: 0.4224, class1 acc: 94.660, class2 acc: 98.015, class3 acc: 97.362
    2022-02-23 16:53:25: Epoch [1], Batch [60], train-loss: 0.3861, class1 acc: 94.057, class2 acc: 97.624, class3 acc: 95.716
    2022-02-23 16:53:40: Epoch [1], Batch [80], train-loss: 0.3689, class1 acc: 94.118, class2 acc: 97.754, class3 acc: 95.444
    2022-02-23 16:53:54: Epoch [1], Batch [100], train-loss: 0.3326, class1 acc: 92.769, class2 acc: 98.245, class3 acc: 94.348
    2022-02-23 16:54:08: Epoch [1], Batch [120], train-loss: 0.3048, class1 acc: 93.446, class2 acc: 97.996, class3 acc: 95.004
    2022-02-23 16:54:23: Epoch [1], Batch [140], train-loss: 0.3570, class1 acc: 93.433, class2 acc: 98.012, class3 acc: 94.916
    2022-02-23 16:54:37: Epoch [1], Batch [160], train-loss: 0.3236, class1 acc: 92.567, class2 acc: 98.614, class3 acc: 94.335
    2022-02-23 16:54:52: Epoch [1], Batch [180], train-loss: 0.3651, class1 acc: 92.132, class2 acc: 97.685, class3 acc: 93.697
    2022-02-23 16:55:08: Epoch [1], Batch [200], train-loss: 0.3106, class1 acc: 93.269, class2 acc: 98.169, class3 acc: 94.849
    2022-02-23 16:55:24: Epoch [1], Batch [220], train-loss: 0.3878, class1 acc: 91.229, class2 acc: 97.875, class3 acc: 93.294
    2022-02-23 16:55:40: Epoch [1], Batch [240], train-loss: 0.3361, class1 acc: 93.295, class2 acc: 98.609, class3 acc: 94.954
    2022-02-23 16:55:56: Epoch [1], Batch [260], train-loss: 0.3350, class1 acc: 94.269, class2 acc: 98.122, class3 acc: 95.307
    2022-02-23 16:56:11: Epoch [1], Batch [280], train-loss: 0.3167, class1 acc: 93.818, class2 acc: 98.189, class3 acc: 94.971
    2022-02-23 16:56:27: Epoch [1], Batch [300], train-loss: 0.3230, class1 acc: 91.987, class2 acc: 98.382, class3 acc: 93.885

Also, how would I calculate other metrics such as F1 during training, since F1 would be a better indicator of the model’s performance? Any help would be highly appreciated!
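
For F1, one option is to accumulate true positives, false positives and false negatives per class across batches and compute F1 = 2TP / (2TP + FP + FN) when logging, rather than averaging per-batch F1 scores. A minimal sketch, assuming the layout described in the background section (logits [N, C, H, W], target [N, C + 1, H, W] with channel 0 equal to 1 on valid pixels); `MultiLabelF1` is a hypothetical helper, not part of DFF or PyTorch:

```python
import torch


class MultiLabelF1:
    """Accumulates per-class TP/FP/FN over batches and reports per-class F1."""

    def __init__(self, num_classes, threshold=0.5):
        self.threshold = threshold
        self.tp = torch.zeros(num_classes)
        self.fp = torch.zeros(num_classes)
        self.fn = torch.zeros(num_classes)

    @torch.no_grad()
    def update(self, logits, targets):
        valid = targets[:, :1].bool()      # [N, 1, H, W], assumed 1 on valid pixels
        labels = targets[:, 1:].bool()     # [N, C, H, W]
        preds = torch.sigmoid(logits) > self.threshold
        dims = (0, 2, 3)                   # sum over batch and space, keep classes
        self.tp += (preds & labels & valid).sum(dim=dims).cpu().float()
        self.fp += (preds & ~labels & valid).sum(dim=dims).cpu().float()
        self.fn += (~preds & labels & valid).sum(dim=dims).cpu().float()

    def compute(self):
        # clamp avoids 0/0 for a class with no edges and no predicted edges (F1 = 0)
        denom = (2 * self.tp + self.fp + self.fn).clamp(min=1)
        return 2 * self.tp / denom

    def reset(self):
        for counts in (self.tp, self.fp, self.fn):
            counts.zero_()
```

It could be updated with `outputs[1]` and `target` every batch, read with `compute()` every 20 batches or at the end of the epoch, and cleared with `reset()`. This is a strict pixel-wise F1; standard edge-detection benchmarks typically match boundary pixels within a small distance tolerance, so the numbers will not be directly comparable to published results.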

@ptrblck Could you please help me? Any hints or tips would be very helpful.

The accuracy calculation looks reasonable based on your explanation.
You could apply an explicit threshold instead of using torch.round, but round should also work, since rounding the sigmoid output is equivalent to a threshold of 0.5.
Since you are concerned about the high accuracy, print the predictions of a single batch, calculate the accuracy manually, and compare it to the value your code reports.
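
The threshold suggestion can be written as a small hypothetical helper. With `threshold=0.5` it makes exactly the same decision as `torch.round(torch.sigmoid(logits))`, because torch.round rounds half to even (0.5 becomes 0). The benefit of an explicit threshold is that it can be raised or lowered, for example tuned on validation data, when edge pixels are rare:

```python
import torch


def binarize(logits, threshold=0.5):
    """Convert edge logits into {0., 1.} predictions with a tunable threshold.

    At threshold=0.5 this matches torch.round(torch.sigmoid(logits)) and is
    mathematically the same as logits > 0 (up to floating-point rounding
    for logits extremely close to 0).
    """
    return (torch.sigmoid(logits) > threshold).float()


logits = torch.tensor([-3.0, -0.1, 0.1, 3.0])
print(binarize(logits))                 # default 0.5 threshold
print(binarize(logits, threshold=0.3))  # a lower threshold marks more pixels as edges
```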

@ptrblck Thank you so much for your feedback and your time! I also computed a confusion matrix for a batch of predictions, and it gives the same results. What do you mean by manually calculating the accuracy?
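
A confusion matrix is most informative here if it is built per class and only over valid pixels, because it then shows directly how much of the accuracy comes from true negatives (correctly predicted background). A hypothetical helper, assuming the same target layout as in the question (channel 0 equal to 1 on valid pixels):

```python
import torch


@torch.no_grad()
def per_class_confusion(logits, targets, threshold=0.5):
    """Return a [C, 2, 2] tensor of counts indexed as [class, true label, predicted label].

    Only valid pixels are counted, assuming targets[:, 0] == 1 on valid
    pixels and targets[:, 1:] holds the binary per-class edge labels.
    """
    num_classes = logits.size(1)
    valid = targets[:, 0].bool()                        # [N, H, W]
    labels = targets[:, 1:].long()                      # [N, C, H, W]
    preds = (torch.sigmoid(logits) > threshold).long()  # [N, C, H, W]
    cm = torch.zeros(num_classes, 2, 2, dtype=torch.long)
    for c in range(num_classes):
        t = labels[:, c][valid]   # 1-D tensor of valid labels for class c
        p = preds[:, c][valid]
        # encode each (true, pred) pair as 0..3 and count occurrences
        cm[c] = torch.bincount(2 * t + p, minlength=4).view(2, 2).cpu()
    return cm
```

For class `c`, accuracy is `cm[c].diagonal().sum() / cm[c].sum()`, precision is `cm[c, 1, 1] / cm[c, :, 1].sum()` and recall is `cm[c, 1, 1] / cm[c, 1].sum()`. If `cm[c, 0, 0]` dwarfs the other entries, a high accuracy says little about how well edges are found.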

You could print the predictions of a single batch and compare them to the targets of the current batch by counting the “right” vs. “wrong” predictions. You could then compare the manually calculated accuracy against the reported one and check if your metric calculation is wrong.
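
A tiny version of this check: take a made-up 2x3 logit map and its labels for one class, count the matches with explicit Python loops, and compare with the vectorized `(pred == target).sum()` used in `get_accuracy_stats`. The tensors and the `manual_correct_count` helper are hypothetical; on real data you would take one class channel of `outputs[1]` and of `target[:, 1:]` from a single batch:

```python
import torch


def manual_correct_count(pred, target):
    """Count matching pixels with explicit Python loops: slow, but easy to inspect."""
    correct = 0
    for pred_row, target_row in zip(pred.tolist(), target.tolist()):
        for p, t in zip(pred_row, target_row):
            correct += int(p == t)
    return correct


# Hypothetical 2x3 logits and labels for a single class of a single image.
logits = torch.tensor([[2.0, -1.0, 0.5], [-3.0, 1.5, -0.2]])
target = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])

pred = torch.round(torch.sigmoid(logits))   # same rule as get_accuracy_stats
print(pred)
print(target)

manual = manual_correct_count(pred, target)
vectorized = int((pred == target).sum().item())
print(manual, vectorized)   # both 4: 4 of the 6 pixels match
print(100 * manual / target.numel())
```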

Thank you for your help!