Hi, I am new to PyTorch and deep learning, and I am trying to implement this paper
(Paper) using epochs, mini-batch data loaders, and SGD.
Is this the correct approach?
def _centers_pass(loader, loss_fn, update_U, opt_centers):
    """Run one SGD pass over ``loader`` that updates only the centers.

    Returns:
        ``(u, total_loss)``: the memberships computed for the last batch
        (``None`` if the loader is empty) and the summed per-batch loss as a
        Python float.
    """
    import torch

    u = None
    total_loss = 0.0
    for batch in loader:
        opt_centers.zero_grad()
        # Alternating optimization: U is held fixed while the centers take a
        # gradient step, so it is computed outside the autograd graph. This
        # also removes any need for ``backward(retain_graph=True)``.
        with torch.no_grad():
            u = update_U(batch)
        loss = loss_fn(batch, u)
        loss.backward()
        # ``step`` must be *called*; referencing the bound method does nothing.
        opt_centers.step()
        # Accumulate a float: summing the loss tensor itself would keep every
        # batch's autograd graph alive for the whole pass.
        total_loss += loss.item()
    return u, total_loss


def _fit_centers(model, loader, loss_fn, update_U, opt_centers, threshold,
                 max_passes):
    """Repeat center passes until the centers move less than ``threshold``.

    U is not tested separately. The assumption is that ``update_U`` depends
    only on the batch and the current centers (TODO confirm against the
    paper), so U stops changing once the centers do. Gives up after
    ``max_passes`` passes so a criterion that never triggers cannot hang.

    Returns:
        ``(u, F_loss)`` from the last pass run.
    """
    u, F_loss = None, 0.0
    for _ in range(max_passes):
        prev_centers = model.centers.detach().clone()
        u, F_loss = _centers_pass(loader, loss_fn, update_U, opt_centers)
        shift = (model.centers.detach() - prev_centers).norm().item()
        if shift < threshold:
            break
    return u, F_loss


def _fit_W(loader, update_W, G_loss_function, threshold, max_passes):
    """Repeat proximal W passes until W changes less than ``threshold``.

    Returns:
        ``(G_loss, converged)``. ``G_loss`` is the summed W objective over the
        last pass. ``converged`` is False only when ``max_passes`` ran out.
    """
    prev_W = None
    G_loss = 0.0
    for _ in range(max_passes):
        G_loss = 0.0
        W_prox = None
        for batch in loader:
            W_prox = update_W(batch)
            G_loss += float(G_loss_function(W_prox))
        if W_prox is None:  # empty loader: nothing to optimize
            return G_loss, True
        W_now = W_prox.detach().clone()
        # Test the *change* in W between passes. ``W_prox < threshold`` would
        # be an element-wise comparison and raises for multi-element tensors.
        if prev_W is not None and (W_now - prev_W).norm().item() < threshold:
            return G_loss, True
        prev_W = W_now
    return G_loss, False


def train(model, loader, loss_fn, update_U, update_W, G_loss_function,
          lr, epochs, threshold, max_passes=100, history=None):
    """Alternately fit cluster centers/memberships (U) and W with mini-batch SGD.

    Each epoch runs three stages:

    1. Passes over ``loader`` that update the centers with SGD (U fixed per
       batch) until the centers move less than ``threshold`` per pass.
    2. Passes that apply the proximal W update until W changes less than
       ``threshold`` per pass.
    3. If W converged, one more center pass so the centers reflect the final W.

    Every stage is capped at ``max_passes`` passes.

    Args:
        model: object exposing ``model.centers``, a leaf tensor with
            ``requires_grad=True`` (e.g. an ``nn.Parameter``).
        loader: re-iterable of batches (e.g. a ``DataLoader``).
        loss_fn: ``loss_fn(batch, u) -> scalar tensor``. The objective
            minimized w.r.t. ``model.centers``; it must use the centers
            (e.g. via a closure over ``model``).
        update_U: ``update_U(batch) -> tensor``. Membership update for a batch.
        update_W: ``update_W(batch) -> tensor``. Proximal step on W; returns
            the updated W.
        G_loss_function: ``G_loss_function(W) -> scalar``. The W objective,
            recorded for monitoring only.
        lr: SGD learning rate for the centers.
        epochs: number of outer epochs (an int).
        threshold: convergence tolerance for the center shift and W change.
        max_passes: maximum passes per stage before giving up.
        history: optional list; if given, one ``{"F_loss", "G_loss"}`` dict
            is appended per epoch.

    Returns:
        ``(model.centers, u)``, where ``u`` holds the memberships of the last
        batch processed (not of the whole dataset).
    """
    from torch import optim

    # SGD expects an iterable of tensors; a bare tensor raises TypeError.
    opt_centers = optim.SGD([model.centers], lr=lr)
    u = None
    # ``epochs`` is a count, so iterate a range. Convergence state lives
    # inside the stage helpers and is therefore fresh for every epoch.
    for _ in range(epochs):
        u, F_loss = _fit_centers(model, loader, loss_fn, update_U,
                                 opt_centers, threshold, max_passes)
        G_loss, W_converged = _fit_W(loader, update_W, G_loss_function,
                                     threshold, max_passes)
        if W_converged:
            u, F_loss = _centers_pass(loader, loss_fn, update_U, opt_centers)
        if history is not None:
            history.append({"F_loss": F_loss, "G_loss": G_loss})
    return model.centers, u