Suppose the original loss of the model is loss_ori.
I want to add a term to the loss function based on the weight of every Conv2d layer, like this:
before training:

import torch
import torch.nn as nn

def register_buffers(m):
    if isinstance(m, nn.Conv2d):
        O, I, H, W = m.weight.shape  # Conv2d weight is (out_channels, in_channels, kH, kW)
        U = torch.randn(O, I, H, W, device=m.weight.device)
        m.register_buffer("U", U)
        m.register_buffer("re_loss", torch.zeros(1, device=m.weight.device))

net.apply(register_buffers)
when training:

def admm_loss(m):
    mse = nn.MSELoss()
    if isinstance(m, nn.Conv2d):  # only Conv2d modules have a U buffer registered
        m.re_loss = mse(m.weight, m.U)

net.apply(admm_loss)

re_loss = 0.
for name, para in net.state_dict().items():
    if "re_loss" in name:
        re_loss = re_loss + para

loss = loss_ori + 1.5e-4 * re_loss
loss.backward()
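
In case it helps, a minimal self-contained sketch of the setup I am describing would be roughly the following (the toy net, data, and loss_ori here are just placeholders, not my real model; register_buffers and admm_loss are as defined above):

# toy stand-ins for my real model and data
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
net.apply(register_buffers)

opt = torch.optim.SGD(net.parameters(), lr=0.1)
x = torch.randn(4, 3, 16, 16)
target = torch.randn(4, 3, 16, 16)

for step in range(10):
    opt.zero_grad()
    loss_ori = nn.functional.mse_loss(net(x), target)  # placeholder for my real loss_ori

    net.apply(admm_loss)
    re_loss = 0.
    for name, para in net.state_dict().items():
        if "re_loss" in name:
            re_loss = re_loss + para

    loss = loss_ori + 1.5e-4 * re_loss
    loss.backward()
    opt.step()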
However, the program doesn't work as expected.
I wonder whether my change to the loss function is correct. (Regardless of U, just focus on whether the gradient of the weights involves both loss_ori and re_loss.)
It seems that neither loss_ori nor re_loss converges.
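
To make the question concrete, this is the kind of check I have in mind, reusing the toy net, x, and target from the sketch above (again just a sketch, not my real code):

net.zero_grad()
loss_ori = nn.functional.mse_loss(net(x), target)

net.apply(admm_loss)
re_loss = sum(para for name, para in net.state_dict().items() if "re_loss" in name)

# backward the combined loss and save the conv weight gradient
(loss_ori + 1.5e-4 * re_loss).backward(retain_graph=True)
conv = next(m for m in net.modules() if isinstance(m, nn.Conv2d))
g_total = conv.weight.grad.clone()

# backward loss_ori alone and compare
net.zero_grad()
loss_ori.backward()
g_ori = conv.weight.grad.clone()

print((g_total - g_ori).abs().max())

If the max difference comes out as exactly zero, I would conclude that re_loss never reaches the weight gradients, which is what I want to rule out.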