I am trying to reproduce the loss function from Eq. 9 in this paper. The idea is that it sums over different matchings of elements in the output and target to create a permutation-invariant version of cross-entropy loss. This is my current implementation:
import torch
from torch import nn

CrossEntropyLoss = nn.CrossEntropyLoss()

def SetCrossEntropyLoss(output, target):
    # output: (N, C) logits, target: (N,) class indices
    loss = 0
    for i in range(output.shape[0]):
        se = 0
        for j in range(output.shape[0]):
            # cross entropy between the i-th output and the j-th target label
            H = CrossEntropyLoss(output[i].reshape(1, output.shape[1]), target[j].reshape(1))
            se += torch.exp(-H)
        loss += -torch.log(se)  # per-element log-sum-exp term over matchings
    return loss
I have two questions, though.

Would the loops and in-place operations mess with the gradients being computed correctly?

Any ideas on how to make this function more elegant using PyTorch operations?
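For comparison, here is a sketch of a vectorized version. It assumes `output` is an `(N, C)` tensor of logits, `target` is an `(N,)` tensor of class indices, and that the intended per-element loss is `-log Σ_j exp(-H_ij)` (my reading of the inner `se += torch.exp(-H)` accumulation above), where `H_ij` is the cross entropy of output `i` against target `j`:

```python
import torch
import torch.nn.functional as F

def set_cross_entropy_loss(output, target):
    """Permutation-invariant cross entropy, vectorized.

    output: (N, C) logits, target: (N,) class indices.
    Assumed per-element loss: -log sum_j exp(-H_ij).
    """
    log_probs = F.log_softmax(output, dim=1)   # (N, C)
    # H[i, j] = -log p_i(target_j): pairwise cross-entropy matrix, shape (N, N)
    H = -log_probs[:, target]
    # -log sum_j exp(-H_ij), computed stably with logsumexp, summed over i
    return -torch.logsumexp(-H, dim=1).sum()
```

Since everything is expressed as tensor operations, autograd differentiates it directly, and `logsumexp` avoids overflow in the `exp`/`log` round trip. For what it's worth, the loop version also backpropagates fine: Python loops are simply unrolled into the autograd graph, and `se += torch.exp(-H)` on a plain Python accumulator creates a new tensor rather than mutating one in place.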