# Minimum Loss Over A Column

Suppose I have a tensor `output`, which is the result of my network, and a tensor `target`, which is my desired output:

``````
output.shape = torch.Size([B, C, N])
target.shape = torch.Size([B, C, N])
``````

I want the network to predict each `N` correctly, regardless of the particular permutation of `C` it outputs, since the channels have no inherent order.

For this reason, I would like to calculate the loss for every possible permutation of `C` between `output` and `target`, and take the minimum overall loss.

To demonstrate what I want to do, it would be written in plain Python as follows:

``````
import torch

def Loss(target, output):
    loss = 0
    min_loss = 0

    # For each target channel, find the output channel with the smallest
    # sum of squared differences, and accumulate it
    for b in range(target.shape[0]):
        for c_i in range(target.shape[1]):
            for c_ii in range(target.shape[1]):
                loss_temp = torch.sum((target[b, c_i] - output[b, c_ii])**2)
                if c_ii == 0 or loss_temp < min_loss:
                    min_loss = loss_temp
            loss = loss + min_loss

    # Calculate mean over batches
    loss = loss / target.shape[0]

    return loss
``````
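To make this runnable in isolation (the tensor shapes below are arbitrary, and the helper name `loop_loss` is mine), the same loop can be exercised like so. Note that, as written, it picks the best output channel for each target channel independently, so it is a lower bound on the true minimum over permutations:

``````python
import torch

# Self-contained restatement of the loop above
def loop_loss(target, output):
    loss = 0.0
    for b in range(target.shape[0]):
        for c_i in range(target.shape[1]):
            min_loss = None
            for c_ii in range(output.shape[1]):
                # Sum of squared differences between one target channel
                # and one output channel
                loss_temp = torch.sum((target[b, c_i] - output[b, c_ii]) ** 2)
                if min_loss is None or loss_temp < min_loss:
                    min_loss = loss_temp
            loss = loss + min_loss  # accumulate the best match per channel
    return loss / target.shape[0]  # mean over batches

torch.manual_seed(0)
target = torch.randn(2, 3, 4)
output = torch.randn(2, 3, 4)
loss = loop_loss(target, output)
``````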

Is there a more elegant, PyTorch-oriented way of performing this operation?

I have made an attempt, with the following method:

• use `itertools.permutations` to index one of the tensors;

• repeat the other tensor along a new axis;

• calculate the sum of squared differences over the last axes;

• take the minimum along the new permutation axis;

• average over the batch.

``````
import itertools

import numpy as np
import torch

def Loss(target, output):
    # Calculate all permutations of the channel indices
    idx = torch.from_numpy(np.array(list(itertools.permutations(range(output.shape[-2])))))

    # Index the tensor with the permutation indices: [B, P!, C, N]
    output_perms = output[:, idx, :]

    # Repeat the other tensor to the same length
    target_perms = target.unsqueeze(1).repeat(1, len(idx), 1, 1)

    # Calculate the sum of squared differences per permutation
    losses = (output_perms - target_perms)**2
    losses = losses.flatten(start_dim=-2)
    loss_len = losses.shape[-1]
    losses = torch.sum(losses, dim=-1)

    # Minimum along the permutation axis, then mean along the batch axis.
    # Remember, we still need to divide by the number of entries we summed over!
    min_loss = torch.mean(losses.min(dim=-1, keepdim=True)[0]) / loss_len
    return min_loss
``````
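One way to sanity-check this kind of loss is to verify that it is actually permutation-invariant: shuffling the channels of the network output should leave the value unchanged. Below is a condensed restatement of the steps above (the helper name `perm_min_mse` is mine; it uses broadcasting instead of `repeat`, and `mean` in place of the explicit division), followed by that check:

``````python
import itertools
import torch

def perm_min_mse(target, output):
    # Enumerate channel permutations, compute the MSE for each,
    # keep the smallest, then average over the batch.
    idx = torch.tensor(list(itertools.permutations(range(output.shape[1]))))
    output_perms = output[:, idx, :]            # [B, P!, C, N]
    target_rep = target.unsqueeze(1)            # [B, 1, C, N], broadcasts
    sq = (output_perms - target_rep) ** 2       # [B, P!, C, N]
    per_perm = sq.flatten(start_dim=-2).mean(dim=-1)  # [B, P!]
    return per_perm.min(dim=-1).values.mean()   # min over perms, mean over batch

torch.manual_seed(0)
B, C, N = 2, 3, 5
target = torch.randn(B, C, N)
output = torch.randn(B, C, N)

base = perm_min_mse(target, output)
shuffled = output[:, [2, 0, 1], :]  # fixed channel permutation
# Permuting the channels only relabels which permutation attains the minimum,
# so the loss value is identical.
``````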