Hi, I have a custom loss function, and wish to update weights of the center?
and the other problem is when i write loss,centers = center_loss() I get error that loss function is not iteratble.
see the following code
class center_loss(torch.nn.Module):
def __init__(self,num_class=10,num_dim=2,num_batch=100):
super(center_loss,self).__init__()
self.num_class = num_class
self.num_dim = num_dim
self.num_batch = num_batch
self.centers = torch.nn.Parameter(torch.randn(num_class,num_dim).cuda())
def forward(self,features,labels):
index = labels.type(torch.cuda.FloatTensor) # casting the variables to int
features = features.type(torch.cuda.FloatTensor)
#calculating new centers
var4 = self.centers.unsqueeze(1)
var5 = var4.transpose(0,1)
ar6 = var5.repeat(10,1,1)
labels = one_hot_embedding(index,10)
embed2 = labels.unsqueeze(1).view(self.num_batch,10,1)
embed3 = embed2.repeat(1,1,2)
embed3 = embed3.type(torch.FloatTensor)
var6 =var6.type(torch.FloatTensor)
a = torch.mm(var6[:,:,0],embed3[:,:,0].t())
b = torch.mm(var6[:,:,1],embed3[:,:,1].t())
aa = a.t()
bb = b.t()
c = torch.zeros(self.num_batch,2)
c[:,0] = aa[:,1]
c[:,1] = bb[:,1]
diff = c - features.type(torch.FloatTensor)
imm = torch.zeros(self.num_batch,2)
imm.index_add_(0, index, diff)
unique, counts = np.unique(index.numpy(), return_counts=True)
s = torch.Tensor(self.num_batch).view(10,1)
centers_update = self.centers - (imm/s)*0.5
# calculating loss
x_2 = torch.pow(features,2)
c_2 = torch.pow(self.centers,2)
x_2_s = torch.sum(x_2,1)
x_2_s = x_2_s.view(self.num_batch,1)
c_2_s = torch.sum(c_2,1)
c_2_s = c_2_s.view(self.num_class,1)
x_2_s_e = x_2_s.repeat(1,self.num_class)
c_2_s_e = c_2_s.t().repeat(self.num_batch,1)
xc = 2*torch.mm(features,center.t())
# we want only positive values,
dis = x_2_s_e + c_2_s_e - xc
di = dis.type(torch.FloatTensor)
di = torch.sqrt(torch.clamp(di, min=0))
# since center loss focuses on intra distances, we are not concerened about the distance that we calculated from
#other centers, we will use the other centers to increase inter loss.
bl = labels.type(torch.ByteTensor)
dii = torch.masked_select(di,bl)
center_loss = dii/self.num_batch
return center_loss, self.centers
I want he self.centers weight to be update by centers_update value