For loop alternative in GPU tensor comprehension

Hi
I’m dealing with a loss function that requires computing the pairwise inner products of the samples within each class of my dataset. Since my classes have different numbers of samples, I cannot store all my data in one regular tensor and use native PyTorch operations to compute the loss. Instead, I have to loop over the class labels, compute the loss for each class separately, and then average, as below:

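import torch

def loss(output, labels):
    # Map each label to a contiguous class index 0..C-1.
    unq_idx, rev_idx = torch.unique(labels, return_inverse=True)
    # Allocate in output's dtype; zeros_like(unq_idx) would inherit the label dtype (usually integer).
    cdot = torch.zeros(len(unq_idx), dtype=output.dtype, device=output.device)
    for i in range(len(unq_idx)):
        output_i = output[rev_idx == i]  # samples of class i
        # Mean of all pairwise inner products within class i.
        cdot[i] = output_i.matmul(output_i.transpose(-1, -2)).mean((-1, -2))
    return cdot.mean()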

The loss is computed on the GPU, but it is much slower than computing the same loss with native PyTorch operations (no for loop) on a balanced dataset of the same total size and average class size. Is there a PyTorch function (or a combination of functions) that lets me build the loss tensor without a Python loop when each element depends on a slice of output with a different length?
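
One direction that might avoid the loop, sketched here under the assumption that output is 2-D with shape (N, D) and labels is 1-D with shape (N,): the mean of all pairwise inner products within a class equals the squared norm of the class sum divided by the squared class size, so per-class sums and counts may be enough. The name loss_vectorized is only illustrative, and this has not been benchmarked.

```python
import torch

def loss_vectorized(output, labels):
    # Assumption: output is (N, D), labels is (N,).
    unq_idx, rev_idx = torch.unique(labels, return_inverse=True)
    num_classes = unq_idx.numel()
    # Per-class sum of sample vectors, shape (C, D), accumulated in one scatter-add.
    sums = torch.zeros(num_classes, output.shape[1],
                       dtype=output.dtype, device=output.device)
    sums.index_add_(0, rev_idx, output)
    # Number of samples in each class, shape (C,).
    counts = torch.bincount(rev_idx, minlength=num_classes).to(output.dtype)
    # mean_{a,b} <x_a, x_b> = ||sum_a x_a||^2 / n^2 for each class.
    cdot = sums.pow(2).sum(dim=1) / counts.pow(2)
    return cdot.mean()
```

This relies on the identity above, so it would not carry over directly if the per-class reduction were something other than a plain mean over all pairs (for example, if the diagonal had to be excluded).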