Multi-pair losses

Hi, I have a question about calculating losses over multiple pairs of inputs.

As far as I know, PyTorch's loss functions only compute a loss between a single input and a single target.
In my case, I have several embeddings of the same shape (i.e. A: (N, D), B: (N, D), C: (N, D)) and want to pull them toward each other using something like an L2 loss.
Is there a built-in function in PyTorch that can compute a loss over multiple pairs of inputs, as in the case above?

If not, what is the right way to pull multiple inputs toward each other?

I think there are two cases:

Case 1 (every ordered pair):
input:A, target:B
input:A, target:C
input:B, target:A
input:B, target:C
input:C, target:A
input:C, target:B

Case 2 (each unordered pair once):
input:A, target:B
input:A, target:C
input:B, target:C

cf) I'm considering two cases because I'm not sure whether the target embeddings also receive gradients.
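A minimal sketch of the gradient question, assuming the L2 loss is written by hand as a mean squared difference (the `l2` helper, the sizes, and the random data are illustrative assumptions). With a hand-written loss, autograd sends gradients into both operands, and calling `.detach()` on the target stops them. Under that assumption, the first case (every ordered pair, targets detached) yields the same gradients as the second case (each unordered pair once, nothing detached):

```python
import torch

torch.manual_seed(0)
N, D = 4, 8  # illustrative sizes
A = torch.randn(N, D, requires_grad=True)
B = torch.randn(N, D, requires_grad=True)
C = torch.randn(N, D, requires_grad=True)

def l2(x, y):
    # Hypothetical helper: mean squared difference.
    # Gradients flow into BOTH x and y unless one of them is detached.
    return ((x - y) ** 2).mean()

# Each unordered pair once, targets NOT detached
loss_sym = l2(A, B) + l2(A, C) + l2(B, C)
grads_sym = torch.autograd.grad(loss_sym, [A, B, C])

# Every ordered pair, targets detached (treated as constants)
loss_dir = sum(l2(x, y.detach()) for x in (A, B, C) for y in (A, B, C) if x is not y)
grads_dir = torch.autograd.grad(loss_dir, [A, B, C])

for g_sym, g_dir in zip(grads_sym, grads_dir):
    print(torch.allclose(g_sym, g_dir))  # True
```

So, for this kind of symmetric loss, the two cases differ only in bookkeeping: detaching targets requires visiting both directions of each pair to give every embedding the same gradient.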

I’m not sure I understand the use case completely, but you can calculate the loss manually using any (differentiable) PyTorch operations; you don’t need to stick to the predefined loss functions in the nn namespace.
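As an illustration of computing the loss manually, here is a sketch of a multi-pair L2 loss built from plain tensor operations. The name `multi_pair_l2_loss` and its `detach_target` flag are hypothetical, not part of PyTorch. It assumes all embeddings share the shape (N, D) and uses the squared L2 distance per row, averaged over the batch:

```python
import itertools
import torch

def multi_pair_l2_loss(embeddings, detach_target=False):
    # Hypothetical helper, not a PyTorch API.
    # embeddings: list of tensors, each of shape (N, D).
    # detach_target=False: each unordered pair once; gradients reach both sides.
    # detach_target=True: every ordered pair; the target side is detached.
    if detach_target:
        pairs = itertools.permutations(embeddings, 2)
    else:
        pairs = itertools.combinations(embeddings, 2)
    total = embeddings[0].new_zeros(())  # scalar accumulator on the same device/dtype
    for x, y in pairs:
        if detach_target:
            y = y.detach()
        # squared L2 distance per row, averaged over the N rows
        total = total + ((x - y) ** 2).sum(dim=1).mean()
    return total

torch.manual_seed(0)
A, B, C = (torch.randn(4, 8, requires_grad=True) for _ in range(3))
loss = multi_pair_l2_loss([A, B, C])
loss.backward()
print(A.grad.shape)  # torch.Size([4, 8])
```

Summing over D before averaging over N is one choice among several; `((x - y) ** 2).mean()` would only rescale the loss by 1/D. The number of pairs grows quadratically with the number of embeddings, which is fine for a handful of tensors like A, B, and C.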