Why is my confusion matrix update so slow?

Recently I have been working on semantic segmentation, but my evaluation runs very slowly. I profiled my code and found that the confusion matrix update is the bottleneck:

torch.cuda.synchronize()  # wait for pending GPU work before starting the timer
t2 = time.time()
conf_m.update(label.flatten(), outs.argmax(1).flatten())
torch.cuda.synchronize()  # wait for the update to finish before stopping the timer
print(time.time() - t2)
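To narrow things down, it can help to benchmark the bincount step in isolation, outside the training loop, since CUDA ops run asynchronously and a slow-looking line can be paying for earlier work. A minimal sketch (the sizes and class count here are made up; substitute your real tensor shapes):

```python
import time

import torch

def time_bincount(numel, num_classes, device):
    """Time one bincount-based confusion-matrix update in isolation.

    Returns (elapsed_seconds, confusion_matrix).
    """
    # Hypothetical random inputs standing in for label / outs.argmax(1).
    a = torch.randint(0, num_classes, (numel,), device=device)
    b = torch.randint(0, num_classes, (numel,), device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    inds = num_classes * a + b
    mat = torch.bincount(inds, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.time() - t0, mat

device = "cuda" if torch.cuda.is_available() else "cpu"
elapsed, mat = time_bincount(1_000_000, 21, device)
print(f"{device}: {elapsed:.4f}s, total counts: {mat.sum().item()}")
```

If this isolated call is fast, the 20 seconds is probably inherited from work queued earlier (e.g., the model forward) rather than from the update itself.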

This takes almost 20 seconds to run. My confusion matrix is a simplified version of torchvision’s implementation:

class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        # a: ground-truth labels, b: predictions; mat[i, j] counts pairs (true=i, pred=j)
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            k = (a >= 0) & (a < n)  # ignore out-of-range labels
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self, with_background=False):
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)  # per-class accuracy (recall under the row=truth convention)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        if with_background is False:
            iu = iu[1:]  # drop the background class (index 0) from IoU
        recall = torch.diag(h) / h.sum(0)  # note: column sums make this precision, not recall
        f1 = 2 * acc * recall / (acc + recall)
        return acc_global, acc, iu, f1
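For reference, the bincount construction used in update() can be sanity-checked on a tiny hand-computed example (the values below are made up for illustration):

```python
import torch

num_classes = 3
# Ground truth and predictions for four "pixels" (toy values).
a = torch.tensor([0, 0, 1, 2])
b = torch.tensor([0, 1, 1, 2])

# Same construction as in update(): mat[i, j] counts pairs (true=i, pred=j).
inds = num_classes * a + b
mat = torch.bincount(inds, minlength=num_classes ** 2).reshape(num_classes, num_classes)

# Expected counts: (0,0)=1, (0,1)=1, (1,1)=1, (2,2)=1.
h = mat.float()
acc_global = torch.diag(h).sum() / h.sum()
print(mat)
print(acc_global.item())  # 3 correct out of 4 pixels -> 0.75
```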

Is there a problem with my confusion matrix, or is it something else I'm not aware of? My experiment setup:
GPU: RTX 4090
CPU: AMD EPYC 9654 96-Core Processor
Dataset: MS COCO
torch: 2.0.0+cu118

update: It's weird. Sometimes it takes 20 seconds, and sometimes only 5.
update: I moved label and outs to the CPU before calling conf_m.update(). The update now takes 1.5 seconds. I believe the problem is CUDA-related, but I can't tell where.
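If moving tensors to the CPU helps, one thing worth trying (an assumption, not a verified fix) is keeping the update on the GPU but replacing torch.bincount with scatter_add_, which avoids bincount's data-dependent output sizing. A sketch of the same update written that way, with a CPU equivalence check against the bincount version:

```python
import torch

def update_scatter(mat, a, b):
    """In-place confusion-matrix update using scatter_add_ instead of bincount."""
    n = mat.size(0)
    with torch.no_grad():
        k = (a >= 0) & (a < n)  # ignore out-of-range labels
        inds = n * a[k].to(torch.int64) + b[k].to(torch.int64)
        # Accumulate ones at the flattened (true, pred) indices.
        mat.view(-1).scatter_add_(0, inds, torch.ones_like(inds))

# Quick equivalence check against the bincount construction (runs on CPU).
n = 5
a = torch.randint(0, n, (1000,))
b = torch.randint(0, n, (1000,))
mat = torch.zeros((n, n), dtype=torch.int64)
update_scatter(mat, a, b)
ref = torch.bincount(n * a + b, minlength=n * n).reshape(n, n)
assert torch.equal(mat, ref)
print("scatter_add_ update matches bincount")
```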