# Center-loss computation. NOTE(review): this is the body of a method whose
# header is outside this view; `features`, `labels` and `centers_update` come
# from there -- TODO confirm.
# Builds the matrix of distances between every feature row and every center
# using ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x.c (so only matrix products are
# needed), keeps the entries selected by `labels`, and divides their sum by
# self.num_batch.
# calculating loss
cen = self.centers.clone()
feat = features.clone()
# Squared L2 norm of every feature row (x_2_s) and of every center (c_2_s),
# each reshaped into a column vector.
x_2 = torch.pow(feat,2)
c_2 = torch.pow(cen,2)
x_2_s = torch.sum(x_2,1)
# NOTE(review): .view(self.num_batch,1) raises unless this batch has exactly
# self.num_batch rows (e.g. a smaller final batch would fail); the division
# by self.num_batch further down makes the same assumption.
x_2_s = x_2_s.view(self.num_batch,1)
c_2_s = torch.sum(c_2,1)
c_2_s = c_2_s.view(self.num_class,1)
# NOTE(review): temp_c is never used (its only reference is the commented-out
# line below), yet this line still calls labels.cuda() and therefore requires
# a CUDA device. Consider removing it.
temp_c = torch.mm(labels.cuda(),c_2_s)
# Tile both norm columns into a (self.num_batch, self.num_class) grid so they
# can be combined element-wise with the cross term.
x_2_s_e = x_2_s.repeat(1,self.num_class)
c_2_s_e = c_2_s.t().repeat(self.num_batch,1)
# c_2_s_e = temp_c
# Cross term 2 * x.c for every (feature row, center) pair.
xc = 2*torch.mm(feat,cen.t())
# Squared Euclidean distances, shape (self.num_batch, self.num_class).
dis = x_2_s_e + c_2_s_e - xc
# NOTE(review): .type(torch.FloatTensor) converts to a CPU float32 tensor. If
# `dis` is on the GPU, it is copied to the CPU here, and everything below
# (including the returned loss) is computed on the CPU.
di = dis.type(torch.FloatTensor)
# We want only non-negative values: the expansion above can produce tiny
# negative numbers through floating-point cancellation, and sqrt of a
# negative number is NaN, so clamp at 0 first.
# NOTE(review): the gradient of sqrt is unbounded at 0, so a distance that is
# exactly 0 (or clamped to 0) can yield inf/NaN gradients; adding a small
# epsilon inside the sqrt is a common guard.
di2 = torch.sqrt(torch.clamp(di, min=0))
# Center loss focuses on intra-class distances, so we are not concerned with
# the distances to other centers here; only the entries selected by `labels`
# (presumably a one-hot (num_batch, num_class) matrix -- TODO confirm) are
# kept. The author intends to use the other centers for an inter-class loss.
# NOTE(review): newer PyTorch releases deprecate uint8 (ByteTensor) masks in
# masked_select in favour of bool masks, e.g. labels.bool().
bl = labels.type(torch.ByteTensor)
# 1-D tensor containing only the selected distances.
dii = torch.masked_select(di2,bl)
center_loss = dii.sum()/self.num_batch
# Overwrite the stored centers in place without recording the op in autograd.
# `cen` was cloned above, so the loss just computed uses the pre-update centers.
# NOTE(review): centers_update is not defined in this view -- TODO confirm
# where it is computed.
with torch.no_grad():
self.centers.copy_(centers_update)
return center_loss
This is how I am calculating my loss value.