More Classes => Slower Confusion Matrix (Bincount)

It seems that bincount performance drops dramatically when using a large number of bins, i.e. the number of classes squared.

When training on Mapillary V2 with ~115 classes, bincounting an image of only 256x512 takes anywhere between 80 and 400 ms. In contrast, when training on the higher-resolution Cityscapes (19 classes, 1024x2048, so more pixels to bincount), it usually completes in less than 10 ms.
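To isolate the effect from the rest of the training loop, here's a minimal timing sketch I'd use (random indices as a stand-in for real predictions, and it assumes a CUDA device), with the same synchronize-and-time pattern as the snippets below:

    import time
    import torch

    def time_bincount(n_classes: int, h: int, w: int, iters: int = 20) -> float:
        # Random stand-in for the flattened "n_classes * target + prediction" indices.
        idx = torch.randint(0, n_classes**2, (h * w,), device='cuda')
        torch.cuda.synchronize()
        s_time = time.time()
        for _ in range(iters):
            torch.bincount(idx, minlength=n_classes**2)
        torch.cuda.synchronize()
        return 1000 * (time.time() - s_time) / iters

    print(f" 19 classes @ 1024x2048: {time_bincount(19, 1024, 2048):.2f} ms")
    print(f"115 classes @  256x512:  {time_bincount(115, 256, 512):.2f} ms")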

I presume that bincounting is implemented with atomic add operations, and with more classes there should be fewer collisions, so if anything it should be faster. My suspicion is that bincount uses on-chip memory if the bins fit on an SM, and falls back to global memory when they don't, which is why it becomes super slow. If that's the case, any thoughts on how to improve the speed? Maybe calculating the confusion matrix in chunks instead, to try to get it back into on-chip memory? Theoretically 115 x 115 x 4 bytes (int32) = 52.9 kB, which is less than the 99 kB available to a thread block on SM 8.6. I tried casting the input to int16 to save space, but that didn't change anything; bincount might cast it back to 32 or even 64 bits internally.
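Spelling out that footprint arithmetic (plain Python, nothing GPU-specific):

    # Shared-memory footprint if the whole confusion matrix sat on-chip as int32 counters.
    for n_classes in (19, 115):
        n_bins = n_classes ** 2
        print(f"{n_classes:>3} classes: {n_bins:>5} bins -> {n_bins * 4 / 1e3:.1f} kB")
    #  19 classes:   361 bins ->  1.4 kB
    # 115 classes: 13225 bins -> 52.9 kB  (still below the ~99 kB an SM 8.6 block can be configured for)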

Most of my training loop when running Mapillary is consumed by this bincount, so making it speedy would be of great benefit.

PyTorch: 1.8.1+cu111
GPU: RTX 3090

Just some snippets of my code below for reference.


    def _gen_confusion_mat(self, prediction: torch.Tensor, target: torch.Tensor,
                           mask: torch.Tensor = None) -> torch.Tensor:
        # Flatten each (target, prediction) pair into a single bin index, then bincount.
        if mask is not None:
            conf_mat = torch.bincount(self._n_classes * target[mask] + prediction[mask],
                                      minlength=self._n_classes**2)
        else:
            conf_mat = torch.bincount(self._n_classes * target + prediction,
                                      minlength=self._n_classes**2)
        return conf_mat.reshape(self._n_classes, self._n_classes)

    def add_sample(self, predictions: Dict[str, torch.Tensor],
                   targets: Dict[str, torch.Tensor], loss: int = 0, **kwargs) -> None:
        preds = predictions['seg']    # assuming predictions use the same 'seg' key as targets
        mask = targets['seg'] != 255  # exclude ignore-label pixels
        torch.cuda.synchronize()
        s_time = time.time()

        for idx in range(preds.shape[0]):
            conf_mat = self._gen_confusion_mat(preds[idx], targets['seg'][idx], mask[idx])
            self.metric_data["Confusion_Mat"] += conf_mat

            torch.cuda.synchronize()
            print(f"Confusion mat gen: {1000*(time.time() - s_time):.2f} ms")
            s_time = time.time()

Update: if I chunk it down into slices of 999 bins with a mask, each iteration takes less than a millisecond, and after looping over all of the chunks the overall time spent drops by a factor of 10 (max 40 ms, usually ~10 ms).

I found a few magic numbers in SummaryOps.cu: THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM = 100 and THRESH_NUMBER_BINS_FOR_GLOBAL_MEM = 1000. I'm not sure why these exist, since the proper checks for whether the bins will fit within shared memory are done anyway. Briefly looking at the kernel itself, I'm also not sure why the shared-memory path is limited to 100 bins; I guess using up all the smem would mean only one block could run on an SM at a time. I'm sure a CUDA ninja could explain why this is the case.
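For anyone following along, this is roughly how I read the memory-type selection in SummaryOps.cu, paraphrased from memory into Python-style pseudocode (names and details are approximate, check the actual source for the real checks):

    # Rough paraphrase of how SummaryOps.cu picks where the histogram lives (not the actual source).
    def choose_histogram_memory_type(nbins, counter_bytes, grid_x,
                                     shared_mem_per_block, free_global_mem):
        shared_mem_needed = nbins * counter_bytes          # one counter per bin, per block
        multi_block_mem = nbins * grid_x * counter_bytes   # one partial histogram per block

        if nbins < 100 and shared_mem_needed < shared_mem_per_block:
            # THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM: per-block histogram in shared memory
            return "SHARED"
        if nbins < 1000 and multi_block_mem < free_global_mem // 2:
            # THRESH_NUMBER_BINS_FOR_GLOBAL_MEM: per-block partial histograms in global memory
            return "MULTI_BLOCK"
        # Otherwise: a single global histogram that every block hammers with atomics
        return "GLOBAL"

If I'm reading that right, at 13,225 bins the kernel always takes the fully global path no matter how much shared memory the card actually has, while 999-bin chunks at least land on the multi-block path, which matches the timings above.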

Either way, to help anyone else out, here's my current workaround:

    def _gen_confusion_mat(self, prediction: torch.Tensor, target: torch.Tensor,
                           mask: torch.Tensor = None) -> torch.Tensor:
        # Flatten each (target, prediction) pair into a single bin index as before.
        if mask is not None:
            temp = self._n_classes * target[mask] + prediction[mask]
        else:
            temp = self._n_classes * target + prediction

        # Bincount in chunks of 999 bins so each call stays below the 1000-bin
        # THRESH_NUMBER_BINS_FOR_GLOBAL_MEM threshold and avoids the slow path.
        i = 0
        conf_mat = torch.zeros(self._n_classes**2, dtype=torch.int32, device=target.device)
        while i < self._n_classes**2:
            t_mask = temp >= i
            t_mask &= temp < i + 999
            minlength = self._n_classes**2 - i if i + 999 > self._n_classes**2 else 999
            conf_mat[i:i+999] = torch.bincount(temp[t_mask] - i, minlength=minlength)
            i += 999

        return conf_mat.reshape(self._n_classes, self._n_classes)
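
And a quick standalone sanity check (random data, outside the class above) that the chunked version agrees with a single bincount:

    import torch

    # Compare the chunked approach against one big bincount on random indices.
    n_classes = 115
    temp = torch.randint(0, n_classes**2, (256 * 512,), device='cuda')

    reference = torch.bincount(temp, minlength=n_classes**2)

    chunked = torch.zeros(n_classes**2, dtype=torch.int64, device=temp.device)
    i = 0
    while i < n_classes**2:
        t_mask = (temp >= i) & (temp < i + 999)
        minlength = n_classes**2 - i if i + 999 > n_classes**2 else 999
        chunked[i:i+999] = torch.bincount(temp[t_mask] - i, minlength=minlength)
        i += 999

    assert torch.equal(chunked, reference)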