I am trying to reproduce the loss function from Eq 9 in this paper. The idea is that it goes over different matchings of elements in the output and target to create a permutation invariant version of Cross Entropy Loss. This is my implementation currently

import torch
import torch.nn as nn

CrossEntropyLoss = nn.CrossEntropyLoss()

def SetCrossEntropyLoss(output, target):
    loss = 0
    for i in range(output.shape[0]):
        se = 0
        for j in range(output.shape[0]):
            # cross entropy between prediction i and target j
            H = CrossEntropyLoss(output[i].reshape(1, output.shape[1]), target[j].reshape(1))
            se += torch.exp(-H)
        loss += torch.log(se)
    return -loss
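
In other words (assuming I have translated my own code correctly), the
quantity this should compute is

    L = -\sum_i \log \sum_j \exp( -H(\hat{y}_i, y_j) )

where H(\hat{y}_i, y_j) is the cross entropy between prediction i and
target j.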

I have two questions, though.

Would the loops and in-place operations mess with the gradients being computed correctly?

Any ideas on how one would make this function more elegant using pytorch operations?

Neither the loops nor the inplace operations will cause problems with
backpropagation and gradient computation.

Note, a nice feature of autograd is that it raises an error (rather than
silently computing wrong gradients) if an inplace operation has broken
backpropagation. If that were to happen, you could simply not use inplace
operations, e.g.:

se = se + torch.exp(-H)
loss = loss + torch.log(se)

Even though these look very similar to your inplace versions, they
create new tensors (at some cost in memory), so no inplace operations
occur.
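
For example, here is a minimal sketch (with made-up shapes) showing that
gradients do flow through the looped version of your function:

import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()

output = torch.randn (4, 3, requires_grad = True)   # 4 set elements, 3 classes
target = torch.randint (3, (4,))

loss = 0
for i in range (output.shape[0]):
    se = 0
    for j in range (output.shape[0]):
        H = ce (output[i].reshape (1, output.shape[1]), target[j].reshape (1))
        se += torch.exp (-H)
    loss += torch.log (se)
loss = -loss

loss.backward()                    # no autograd error from the += accumulations
print (output.grad is not None)    # True -- gradients were computed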

It is easy enough to implement CrossEntropyLoss yourself, so it is
also easy to implement a loop-free version of SetCrossEntropyLoss
using pytorch tensor operations:

>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> def SetCrossEntropyLoss(output, target):
...     loss = 0
...     for i in range(output.shape[0]):
...         se = 0
...         for j in range(output.shape[0]):
...             H = torch.nn.CrossEntropyLoss() (output[i].reshape(1,output.shape[1]), target[j].reshape(1))
...             se += torch.exp(-H)
...         loss+=torch.log(se)
...     return -loss
...
>>> output = torch.randn (10, 5)
>>> target = torch.randint (5, (10,))
>>>
>>> SetCrossEntropyLoss (output, target)
tensor(-6.9854)
>>>
>>> def SetCrossEntropyLossB (output, target):
...     # log_probs[i, c] = log-probability of class c for element i
...     log_probs = torch.log_softmax (output, dim = 1)
...     # xEnts[i, j] = cross entropy of prediction i against target j
...     xEnts = -log_probs[:, target]
...     # loss = -sum_i log sum_j exp(-xEnts[i, j])
...     loss = -(-xEnts).exp().sum (dim = 1).log().sum (dim = 0)
...     return loss
...
>>> SetCrossEntropyLossB (output, target)
tensor(-6.9854)
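
As a small aside, the .exp().sum().log() pattern is exactly what
torch.logsumexp computes, so an equivalent (and somewhat more numerically
stable) variant would be something like the following sketch (the name
SetCrossEntropyLossC is just made up):

import torch

def SetCrossEntropyLossC (output, target):
    log_probs = torch.log_softmax (output, dim = 1)
    # xEnts[i, j] = cross entropy of prediction i against target j
    xEnts = -log_probs[:, target]
    # -sum_i log sum_j exp (-xEnts[i, j]), done in one stable logsumexp call
    return -torch.logsumexp (-xEnts, dim = 1).sum ()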