Question about creating a new loss function

(Njuww) #1

Suppose the model's original loss is loss_ori.
I want to add a term to the loss function that depends on the weight of every Conv2d layer, like this:

Before training:

for m in net.modules():
    if isinstance(m, nn.Conv2d):
        I, O, W, H = m.weight.shape
        U = torch.randn(I, O, W, H)
        m.register_buffer("U", U)
        re_loss = torch.FloatTensor(0)
        m.register_buffer("re_loss", re_loss)

During training:

def admm_loss(m):
    loss = nn.MSELoss()
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        m.re_loss = loss(m.weight, m.U)

net.apply(admm_loss)
for name, para in net.state_dict().items():
    if "re_loss" in name:
        re_loss += para
loss = loss_ori + 1.5e-4 * re_loss
loss.backward()


But the program doesn't work as expected.
I wonder whether my change to the loss function is right. (Regardless of U, just focus on whether the gradient of the weights involves both loss_ori and re_loss.)
It seems that neither loss_ori nor re_loss converges.
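
One possible cause, sketched here under assumptions rather than offered as a confirmed diagnosis: tensors returned by `net.state_dict()` are detached from the autograd graph by default, so a `re_loss` summed from them adds no gradient to the weights. The sketch below computes the penalty directly from `m.weight` and `m.U` instead and checks that the weight gradient contains both terms. The toy `net`, the input `x`, the stand-in for `loss_ori`, and the helper name `weight_penalty` are all hypothetical and not from the original post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy network, used only for illustration.
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(4, 2, kernel_size=3),
)

# Before training: attach a target tensor U to every Conv2d layer.
# A buffer is saved with the model but is not returned by net.parameters().
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        m.register_buffer("U", torch.randn_like(m.weight))


def weight_penalty(model):
    """Sum of MSE(weight, U) over Conv2d layers, kept inside the autograd graph."""
    terms = [
        nn.functional.mse_loss(m.weight, m.U)
        for m in model.modules()
        if isinstance(m, nn.Conv2d)
    ]
    return torch.stack(terms).sum()


x = torch.randn(2, 3, 8, 8)
loss_ori = net(x).pow(2).mean()        # stand-in for the real model loss
re_loss = weight_penalty(net)
loss = loss_ori + 1.5e-4 * re_loss
loss.backward()                        # gradients now include both terms

# For contrast: tensors read through state_dict() are detached by default.
print(net[0].weight.requires_grad)                   # True
print(net.state_dict()["0.weight"].requires_grad)    # False
```

Because `re_loss` is built from the live parameters, `loss.backward()` propagates through both `loss_ori` and the penalty. Passing `keep_vars=True` to `state_dict()` would also avoid the detach, but reading the modules directly is simpler.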

(balamurali) #2

What are you trying to do? Can you explain it in more detail?

(Njuww) #3

I'm trying to update the weights using the model loss and a sparse matrix U. I will update U myself; I just want to know whether the way I compute the gradients is right.
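
A minimal sketch of how a manually updated U could sit next to gradient-based weight updates. The thread does not say how U is updated, so `update_U` and its `keep_ratio` parameter are hypothetical placeholders (here: copy the weights and zero all but the largest-magnitude entries). The toy `net`, the input `x`, and the stand-in for the model loss are assumptions too. The point is only that U is changed in place under `torch.no_grad()`, so it stays a registered buffer and never receives a gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))   # toy model, illustration only
conv = net[0]
conv.register_buffer("U", torch.zeros_like(conv.weight))


def update_U(m, keep_ratio=0.25):
    """Hypothetical rule: copy the weights, keep only the largest-magnitude entries."""
    with torch.no_grad():                  # U is updated outside the autograd graph
        w = m.weight.detach()
        k = max(1, int(keep_ratio * w.numel()))
        threshold = w.abs().flatten().topk(k).values.min()
        sparse_w = torch.where(w.abs() >= threshold, w, torch.zeros_like(w))
        m.U.copy_(sparse_w)                # in-place, so the buffer stays registered


opt = torch.optim.SGD(net.parameters(), lr=0.1)
x = torch.randn(2, 3, 8, 8)

for step in range(3):
    update_U(conv)                         # manual update, no gradient involved
    opt.zero_grad()
    loss_ori = net(x).pow(2).mean()        # stand-in for the real model loss
    re_loss = nn.functional.mse_loss(conv.weight, conv.U)
    loss = loss_ori + 1.5e-4 * re_loss
    loss.backward()                        # gradient reaches conv.weight only
    opt.step()
```

The optimizer only sees `net.parameters()`, so it never touches U; U changes only when `update_U` is called.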