I’m dealing with a loss function that requires computing the pairwise inner products of the samples of each class in my dataset. Since my classes have different numbers of samples, I cannot stack all my data into one tensor and use native PyTorch operations to compute the loss. Instead, I have to loop over the class labels and compute the loss for each class separately before averaging, like below:
```python
def loss(output, labels):
    unq_idx, rev_idx = torch.unique(labels, return_inverse=True)
    cdot = torch.zeros(len(unq_idx), device=output.device, dtype=output.dtype)
    for i in range(len(unq_idx)):
        output_i = output[rev_idx == i]  # all samples belonging to class i
        # mean of all pairwise inner products within the class
        cdot[i] = output_i.matmul(output_i.transpose(-1, -2)).mean((-1, -2))
    return cdot.mean()
```
The loss function is computed on the GPU, yet it is much slower than using native PyTorch operations (no for loop) to calculate the same loss on a balanced dataset of the same total size and average class size. I was wondering if there is any PyTorch function (or combination of functions) that would let me compute my loss tensor when each element requires a slice of `output` with a different length.
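One direction I have considered (a sketch, not a tested solution): since the mean pairwise inner product of a class equals the squared norm of the class's vector sum divided by the squared class size, the per-class means can be computed without ragged tensors by accumulating per-class sums with `index_add_` and per-class counts with `bincount`. The function name `loss_vectorized` below is mine, not from any library:

```python
import torch

def loss_vectorized(output, labels):
    # For class i with samples x_1..x_n:
    #   mean_{j,k} <x_j, x_k> = ||sum_j x_j||^2 / n^2
    # so per-class vector sums are enough; no loop over classes is needed.
    unq_idx, rev_idx = torch.unique(labels, return_inverse=True)
    n_classes = len(unq_idx)
    sums = torch.zeros(n_classes, output.shape[-1],
                       device=output.device, dtype=output.dtype)
    sums.index_add_(0, rev_idx, output)                    # per-class sums
    counts = torch.bincount(rev_idx, minlength=n_classes)  # per-class sizes
    per_class = sums.pow(2).sum(-1) / counts.to(output.dtype).pow(2)
    return per_class.mean()
```

If this identity matches the intended loss, it should agree with the loop version up to floating-point error while running as a handful of batched kernels.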