It seems like bincount performance drops dramatically when using a high number of bins, i.e. the number of classes squared.
When training on Mapillary V2 with ~115 classes, bincounting an image of only 256x512 takes anywhere between 80 and 400 ms. When training on higher-resolution Cityscapes (1024x2048, so more pixels to bincount) with only 19 classes, it usually completes in less than 10 ms.
I presume that bincounting is implemented with atomic add operations, where more classes should mean fewer collisions, so it should be faster if anything. My guess is that bincount keeps its bins in on-chip (shared) memory when they fit on an SM, and falls back to global memory atomics when they don't, which would make it much slower. If that's the case, any thoughts on how to speed it up? Maybe calculating the confusion matrix in chunks, to get it back into on-chip memory? In theory 115 x 115 x 4 bytes (int32) = 52.9 kB, which is under the 99 kB of shared memory available per thread block on SM 8.6. I tried casting the input to int16 to save space, but that didn't change anything; bincount might cast it back to int32 or even int64 internally.
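One way to test the chunking idea without writing a custom kernel is to split the flattened indices into slabs of target-class rows, so each bincount call only sees `chunk * n_classes` bins. This is only a sketch (the `chunked_confusion_mat` name and chunk size are mine, and the extra masking passes may eat any gains):

```python
import torch

def chunked_confusion_mat(pred: torch.Tensor, target: torch.Tensor,
                          n_classes: int, chunk: int = 32) -> torch.Tensor:
    # Flattened index into the n_classes x n_classes confusion matrix,
    # same encoding as the original code: row = target, col = prediction.
    idx = (n_classes * target + pred).flatten()
    rows = []
    for lo in range(0, n_classes, chunk):
        hi = min(lo + chunk, n_classes)
        # Keep only indices that fall in this slab of target rows,
        # shifted so the slab starts at bin 0.
        sel = idx[(idx >= lo * n_classes) & (idx < hi * n_classes)] - lo * n_classes
        rows.append(torch.bincount(sel, minlength=(hi - lo) * n_classes))
    return torch.cat(rows).reshape(n_classes, n_classes)
```

Whether the smaller bin count per call actually moves the kernel back onto shared memory is an assumption that would need profiling.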
Most of my training loop time when running Mapillary is consumed by this bincount, so making it speedy would be of great benefit.
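As a baseline for comparison, the same histogram can also be built with `scatter_add_` instead of `bincount`. This is just a sketch; I haven't verified that it takes a different CUDA code path, but it's cheap to benchmark:

```python
import torch

def confusion_mat_scatter(pred: torch.Tensor, target: torch.Tensor,
                          n_classes: int) -> torch.Tensor:
    # Same flattened-index encoding as bincount(n * target + pred),
    # but accumulated with scatter_add_ into a preallocated buffer.
    idx = (n_classes * target + pred).flatten()
    cm = torch.zeros(n_classes ** 2, dtype=torch.long, device=idx.device)
    cm.scatter_add_(0, idx, torch.ones_like(idx))
    return cm.reshape(n_classes, n_classes)
```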
PyTorch: 1.8.1+cu111
GPU: RTX3090
Just some snippets of my code below for reference.
def _gen_confusion_mat(self, prediction: torch.Tensor, target: torch.Tensor,
                       mask: torch.Tensor = None) -> torch.Tensor:
    if mask is not None:
        conf_mat = torch.bincount(self._n_classes * target[mask] + prediction[mask],
                                  minlength=self._n_classes**2)
    else:
        conf_mat = torch.bincount(self._n_classes * target + prediction,
                                  minlength=self._n_classes**2)
    return conf_mat.reshape(self._n_classes, self._n_classes)
def add_sample(self, predictions: Dict[str, torch.Tensor],
               targets: Dict[str, torch.Tensor], loss: int = 0, **kwargs) -> None:
    preds = predictions['seg']
    mask = targets['seg'] != 255
    torch.cuda.synchronize()
    s_time = time.time()
    for idx in range(preds.shape[0]):
        conf_mat = self._gen_confusion_mat(preds[idx], targets['seg'][idx], mask[idx])
        self.metric_data["Confusion_Mat"] += conf_mat
    torch.cuda.synchronize()
    print(f"Confusion mat gen: {1000*(time.time() - s_time):.2f} ms")
    s_time = time.time()
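Independent of the bincount kernel itself, the per-image Python loop in add_sample can be folded into a single call, since the per-image confusion matrices are summed anyway. A sketch of that (the `batch_confusion_mat` name is mine; it assumes the same 255 ignore index as above):

```python
import torch

def batch_confusion_mat(preds: torch.Tensor, targets: torch.Tensor,
                        n_classes: int, ignore_idx: int = 255) -> torch.Tensor:
    # One bincount over the whole (B, H, W) batch instead of B separate
    # kernel launches; equivalent because the per-image matrices are summed.
    mask = targets != ignore_idx
    flat = n_classes * targets[mask] + preds[mask]
    return torch.bincount(flat, minlength=n_classes ** 2).reshape(n_classes, n_classes)
```

This mainly saves kernel-launch and synchronization overhead; it won't fix a slow kernel, but it gives the histogram more work per launch.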