Minimum Loss Over A Column

Suppose I have a tensor, ‘output’, which is the result of my network, and a tensor, ‘target’, which is my desired output:

output.shape = torch.Size([B, C, N])
target.shape = torch.Size([B, C, N])

I only care that the network predicts each of the C vectors of length N correctly, not the particular order in which the network arranges them along the C dimension, since these entries have no meaningful order.

For this reason, I would like to calculate the loss for every possible permutation of the C rows of output against target, and take the minimum overall loss.
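Written out, the permutation loss described here is, for a batch of B examples (a sketch of the intended quantity; $S_C$ denotes the set of all permutations $\pi$ of the C row indices, so the minimum runs over $C!$ candidates):

$$
\mathcal{L}_{\text{perm}} = \frac{1}{B}\sum_{b=1}^{B} \min_{\pi \in S_C} \sum_{c=1}^{C}\sum_{n=1}^{N} \left(\text{output}_{b,\pi(c),n} - \text{target}_{b,c,n}\right)^2
$$

The plain-Python loop below moves the minimum inside the sum over $c$, i.e. it computes $\frac{1}{B}\sum_{b}\sum_{c} \min_{c'} \sum_{n} (\text{output}_{b,c',n} - \text{target}_{b,c,n})^2$, which lets two target rows match the same output row. The attempt further down additionally divides by $C \cdot N$ to turn the sum of squares into a mean squared error.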

In plain Python, what I want to do looks like this:

import torch

def Loss(target, output):
    loss = 0
    # For each target row, find the output row with the smallest sum of squared
    # errors and add that minimum to the loss. Each target row is matched
    # independently, so two target rows may pick the same output row.
    for b in range(target.shape[0]):
        for c_i in range(target.shape[1]):
            for c_ii in range(target.shape[1]):
                loss_temp = torch.sum((target[b, c_i] - output[b, c_ii])**2)
                if c_ii == 0 or loss_temp < min_loss:
                    min_loss = loss_temp
            loss = loss + min_loss

    # Average over the batch
    loss = loss / target.shape[0]

    return loss
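For reference, the triple loop can be written without Python loops by broadcasting. This is a minimal sketch that keeps the loop's semantics (each target row matched to its closest output row, sum of squared errors, averaged over the batch); `nearest_row_loss` is a hypothetical name, not part of PyTorch:

```python
import torch

def nearest_row_loss(target: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    # target, output: [B, C, N]
    # [B, C, 1, N] - [B, 1, C, N] broadcasts to [B, C, C, N];
    # diff[b, i, j] = target[b, i] - output[b, j]
    diff = target.unsqueeze(2) - output.unsqueeze(1)
    pairwise = (diff ** 2).sum(dim=-1)          # [B, C, C] sums of squared errors
    # Closest output row for each target row, summed over rows, averaged over batch
    return pairwise.min(dim=-1).values.sum(dim=-1).mean()

# Usage: min() is differentiable, so gradients flow to the selected output rows
target = torch.randn(8, 3, 5)
output = torch.randn(8, 3, 5, requires_grad=True)
loss = nearest_row_loss(target, output)
loss.backward()
```

The memory cost is a `[B, C, C, N]` intermediate, which is usually small compared with the `[B, C!, C, N]` tensor that enumerating permutations requires.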

Is there a more elegant, PyTorch-oriented way of performing this operation?

I have made an attempt, with the following method:

  • use itertools.permutations to index one of the tensors with every ordering of its C rows;

  • repeat the other tensor along a new (permutation) axis;

  • calculate the sum of squared differences over the last two axes;

  • find the minimum along the new permutation axis;

  • take the average over the batch.
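The five steps above can be sketched on a toy batch in which `output` holds the same rows as `target` in a shuffled order, so the best permutation should give zero loss. This is a minimal sketch; the sizes and the shuffle `[2, 0, 1]` are arbitrary choices for illustration, and `unsqueeze(1)` with broadcasting is used in place of `.repeat()`, which gives the same differences without copying the target C! times:

```python
import itertools
import torch

torch.manual_seed(0)
B, C, N = 2, 3, 4
target = torch.randn(B, C, N)
output = target[:, [2, 0, 1], :]   # same rows as target, in a different order

# Step 1: all C! orderings of the C rows, shape [C!, C]; index one tensor with them
idx = torch.tensor(list(itertools.permutations(range(C))))
output_perms = output[:, idx, :]                      # [B, C!, C, N]

# Step 2: a new singleton axis on the other tensor; broadcasting stands in for .repeat()
target_perms = target.unsqueeze(1)                    # [B, 1, C, N]

# Step 3: sum of squared differences over the last two axes
sse = ((output_perms - target_perms) ** 2).sum(dim=(-2, -1))   # [B, C!]

# Step 4: minimum along the permutation axis (values and winning permutation index)
best = sse.min(dim=-1)

# Step 5: average over the batch
loss = best.values.mean()
print(output_perms.shape, best.indices, loss)
```

`best.indices` identifies which permutation undoes the shuffle for each batch element, which can be useful for inspecting the matching the loss settled on.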

import itertools
import torch

def Loss(target, output):
    # Calculate indices: all C! orderings of the C rows, shape [C!, C]
    idx = torch.tensor(list(itertools.permutations(range(output.shape[-2]))), device=output.device)
    # Index tensor with indices: shape [B, C!, C, N]
    output_perms = output[:, idx, :]

    # Repeat other tensor to same length along a new permutation axis: shape [B, C!, C, N]
    target_perms = target.unsqueeze(1).repeat(1, len(idx), 1, 1)

    # Calculate sum of squares over the last two axes: shape [B, C!]
    losses = (output_perms - target_perms)**2
    losses = losses.flatten(start_dim=-2)
    loss_len = losses.shape[-1]
    losses = torch.sum(losses, dim=-1)
    # Minimum along the permutation axis, then mean along the batch axis. We still need
    # to divide by the number of entries summed in each part of the tensor (C*N).
    min_loss = torch.mean(losses.min(dim=-1).values) / loss_len
    return min_loss
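The loop version and the permutation version do not compute the same quantity. A small check makes the difference concrete; `nearest_row_sse` and `permutation_sse` are hypothetical helper names, and both return a sum of squared errors averaged over the batch, without the division by C·N used in the attempt:

```python
import itertools
import torch

def nearest_row_sse(target, output):
    # Loop semantics: each target row takes its closest output row independently
    pairwise = ((target.unsqueeze(2) - output.unsqueeze(1)) ** 2).sum(dim=-1)  # [B, C, C]
    return pairwise.min(dim=-1).values.sum(dim=-1).mean()

def permutation_sse(target, output):
    # Permutation semantics: one-to-one matching, minimised over all C! orderings
    idx = torch.tensor(list(itertools.permutations(range(output.shape[-2]))))
    sse = ((output[:, idx, :] - target.unsqueeze(1)) ** 2).sum(dim=(-2, -1))  # [B, C!]
    return sse.min(dim=-1).values.mean()

# Two target rows that are both closest to the same output row (B=1, C=2, N=1)
target = torch.tensor([[[0.0], [0.2]]], dtype=torch.float64)
output = torch.tensor([[[0.1], [5.0]]], dtype=torch.float64)

print(nearest_row_sse(target, output))   # both target rows match output row 0
print(permutation_sse(target, output))   # one target row is forced onto output row 1
```

Here the loop-style loss is about 0.02, while the permutation loss is about 23.05, because only one target row may use the output row at 0.1. Which one is wanted depends on whether each output row must be used exactly once. Note also that the permutation version grows as C!: with C = 10 there are already 3,628,800 permutations, so the `[B, C!, C, N]` tensor becomes impractical beyond small C.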