Why is my PyTorch custom loss slow?

This loss helps my model learn to assign similar tag values to the different parts of the same instance, while assigning distinctly different tag values to parts belonging to any two different instances.

The formula is written as follows:
The reference tag value for each instance is denoted h. Assume each instance has K parts (where K is not a fixed number) and each image has N instances; n denotes an instance and n' denotes every instance other than the nth one. L_g is the loss.
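Roughly, it is a pull term (parts of the same instance should have tags close to that instance's reference tag) plus a push term (reference tags of different instances should be far apart). Assuming the usual associative-embedding-style form, with K_n the number of parts of instance n, h_{n,k} the tag of part k of instance n, and σ a bandwidth I am assuming here, it looks something like:

```latex
\bar{h}_n = \frac{1}{K_n} \sum_{k=1}^{K_n} h_{n,k}
\qquad
L_g = \frac{1}{N} \sum_{n=1}^{N} \frac{1}{K_n} \sum_{k=1}^{K_n} \left( h_{n,k} - \bar{h}_n \right)^2
    + \frac{1}{N(N-1)} \sum_{n=1}^{N} \sum_{n' \neq n} \exp\!\left( -\frac{(\bar{h}_n - \bar{h}_{n'})^2}{2\sigma^2} \right)
```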

The main challenge is that N and K are not fixed.

Assume we have a feature map, denoted fm, with shape [channels, height, width] when we consider only one sample of the batch. What we need to do is pull some values on this feature map as close together as possible while pushing other values as far apart as possible. So I use a mask to indicate which pixels belong to the same instance.
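For concreteness, here is a minimal sketch of how such a pull/push loss could be computed for one sample, assuming the per-instance masks are available as a list of boolean [height, width] tensors. The names grouping_loss, fm, masks, and sigma are placeholders of my own, and the exact normalization follows the form written above:

```python
import torch

def grouping_loss(fm, masks, sigma=1.0):
    """Pull/push grouping loss for one sample.

    fm:    tag feature map, shape [channels, height, width].
    masks: list of N boolean tensors, each [height, width];
           masks[n] marks the pixels (parts) of instance n.
    """
    if len(masks) == 0:
        return fm.sum() * 0.0            # no instances: zero loss, graph kept alive

    tags = fm.reshape(fm.size(0), -1)    # [channels, H*W]
    pull, refs = 0.0, []
    for mask in masks:
        idx = mask.reshape(-1).nonzero(as_tuple=True)[0]
        inst_tags = tags[:, idx]         # tags of this instance's K_n pixels
        ref = inst_tags.mean(dim=1)      # reference tag h_bar_n
        refs.append(ref)
        # pull: parts should match their instance's reference tag
        pull = pull + ((inst_tags - ref.unsqueeze(1)) ** 2).mean()

    refs = torch.stack(refs)             # [N, channels]
    n = refs.size(0)
    pull = pull / n

    # push: penalize reference tags of different instances being close
    diff = refs.unsqueeze(0) - refs.unsqueeze(1)   # [N, N, channels]
    dist2 = (diff ** 2).sum(dim=-1)                # squared distance between references
    push = torch.exp(-dist2 / (2 * sigma ** 2))
    push = (push.sum() - n) / max(n * (n - 1), 1)  # drop the n == n' diagonal

    return pull + push
```

Because N and K vary per image, the only Python-level loop here is over the N instances; everything inside the loop and the whole push term stay vectorized on the GPU.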