How would you implement this framework in PyTorch?

Hi, I am a newbie to PyTorch and DL, and I am trying to implement this paper using epochs, batch loaders and SGD.

Is this the correct approach?

opt_centers = optim.SGD(centers, lr)
loss = 0
u = None
for e in epochs:
    while not (C_cconverged and U_converged):
        for batch in loader:
            opt.zerograd()
            u = update_U(X)
            loss =
            loss.backward(retain_graph=True)
            loss_F += loss
            opt_centers.step

            if C < treshold and U < treshold:
                C_cconverged = U_converged = True
                break
    while not (W_converged):
        for batch in loader:
            # proximal operator
            W_prox = update_W()
            G_loss += G_loss_function(W_prox)
            if W_prox < treshold:
                C_cconverged = U_converged = True
                break
    if W_converged:
        for batch in loader:
            opt.zerograd()
            u = update_U(X)
            loss =
            loss.backward(retain_graph=True)
            loss_F += loss
            opt_centers.step

        return model.centers, u
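The `update_W()` step is marked as a proximal operator, but the post doesn't say which regulariser the paper uses. Assuming (hypothetically) an L1 penalty `lam * ||W||_1`, its proximal operator is elementwise soft-thresholding, and one proximal-gradient (ISTA) step is a plain gradient step followed by that shrinkage. The names `soft_threshold` and `prox_grad_step` are illustrative, not taken from the paper; `torch.nn.functional.softshrink` is PyTorch's built-in version of the same shrinkage.

```python
import torch
import torch.nn.functional as F


def soft_threshold(w: torch.Tensor, thresh: float) -> torch.Tensor:
    # Proximal operator of thresh * ||w||_1 (assumed L1 regulariser):
    # shrink every entry toward 0 by `thresh`, and set entries with
    # |w| <= thresh exactly to 0.
    return torch.sign(w) * torch.clamp(w.abs() - thresh, min=0.0)


def prox_grad_step(w: torch.Tensor, grad: torch.Tensor, lr: float, lam: float) -> torch.Tensor:
    # One proximal-gradient (ISTA) step: a gradient step on the smooth part
    # of the loss, then the L1 prox with threshold lr * lam.
    return soft_threshold(w - lr * grad, lr * lam)


w = torch.tensor([0.5, -0.05, 2.0, -1.0])
shrunk = soft_threshold(w, 0.1)

# PyTorch's built-in soft-shrinkage computes the same thing.
same = torch.allclose(shrunk, F.softshrink(w, 0.1))
```

Note that `W_prox < treshold` in the loop compares a whole tensor with a scalar, which yields an elementwise boolean tensor; using that in an `if` raises an error when it has more than one element. A convergence test would more typically look at the change between iterates, e.g. `(w_new - w).norm() < tol`.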

Hi @EMYR,

There are a few mistakes in your code:

opt.zerograd()    # should be opt_centers.zero_grad()
opt_centers.step  # should be opt_centers.step()
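Beyond those two fixes, a few other things would stop the loop from running: `for e in epochs` needs `range(epochs)`, `loss_F` and `G_loss` are never initialised, the `W` loop sets `C_cconverged = U_converged = True` instead of `W_converged` (so it never exits), `optim.SGD` needs an iterable of tensors (passing a single tensor raises a `TypeError`), and `return` only works inside a function. Also, `retain_graph=True` is unnecessary when each batch builds a fresh graph, and `loss_F += loss` keeps every batch's graph alive; accumulating `loss.item()` avoids that.

Below is a minimal runnable sketch of the C/U block only. Since the paper's objective isn't in the post, it assumes a standard fuzzy c-means objective as a stand-in: U is updated in closed form under `torch.no_grad()` and the centers by mini-batch SGD. The function names, the fuzzifier `m`, the tolerance `tol` and the toy data are all hypothetical; the W/proximal block would slot in as a separate loop.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def sq_dists(x, centers):
    # (batch, k) squared Euclidean distances; no sqrt, so gradients stay finite
    return ((x.unsqueeze(1) - centers.unsqueeze(0)) ** 2).sum(dim=-1)


def memberships(x, centers, m=2.0, eps=1e-8):
    # Fuzzy c-means update (assumed stand-in objective):
    # u_ij = 1 / sum_l (d2_ij / d2_il)^(1/(m-1)); each row sums to 1
    d2 = sq_dists(x, centers) + eps
    ratio = (d2.unsqueeze(2) / d2.unsqueeze(1)) ** (1.0 / (m - 1.0))
    return 1.0 / ratio.sum(dim=2)


def fcm_loss(x, centers, u, m=2.0):
    # J = mean_i sum_j u_ij^m * ||x_i - c_j||^2, with U held constant
    return (u ** m * sq_dists(x, centers)).sum(dim=1).mean()


def fit_centers(x, init_centers, m=2.0, lr=0.1, epochs=50, batch_size=32, tol=1e-6):
    loader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=True)
    centers = torch.nn.Parameter(init_centers.clone())
    # optim.SGD expects an iterable of tensors, hence the list
    opt_centers = torch.optim.SGD([centers], lr=lr)
    epoch_loss = 0.0
    for epoch in range(epochs):
        prev = centers.detach().clone()
        epoch_loss = 0.0
        for (xb,) in loader:
            with torch.no_grad():
                u = memberships(xb, centers, m)  # U step: closed form, no graph
            opt_centers.zero_grad()
            loss = fcm_loss(xb, centers, u, m)   # C step: SGD on the objective
            loss.backward()                      # fresh graph per batch, no retain_graph
            opt_centers.step()
            epoch_loss += loss.item()            # .item() so graphs are freed
        if (centers.detach() - prev).norm() < tol:
            break  # centers stopped moving between epochs
    with torch.no_grad():
        u_all = memberships(x, centers, m)
    return centers.detach(), u_all, epoch_loss


# Toy usage: two well-separated 2-D blobs around (0, 0) and (5, 5)
torch.manual_seed(0)
x = torch.cat([torch.randn(100, 2) * 0.1, torch.randn(100, 2) * 0.1 + 5.0])
centers, u, last_loss = fit_centers(x, torch.tensor([[1.0, 1.0], [4.0, 4.0]]))
```

Recomputing U for each batch right before the center step is one design choice; to mirror the structure in the question (iterate C/U until converged, then W, then refine), `fit_centers` would sit inside an outer loop that alternates with the W update.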