Suppose I have a dict named `admm_dict` that contains admm_dict[name] = [U, Z].
I try to compute the ADMM penalty on the weights with:
def admm_penalty():
    admm_loss = 0
    for name, param in net.named_parameters():
        if param.requires_grad and 'weight' in name:
            wei = param.clone()
            U, Z = admm_dict[name]
            admm_loss += rho / 2 * torch.norm(wei - Z + U, 2)
    return admm_loss

loss = admm_penalty()
loss.backward()
It said I should use loss.backward(retain_graph=True), but then I get "CUDA out of memory". How can I call loss.backward() without retain_graph=True?
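For context, here is a minimal runnable sketch of the setup I am describing. The net, shapes, and the zero-initialized U/Z pairs are placeholders, not my real model; the point is that U and Z are detached constants and the penalty is rebuilt from the parameters on every iteration, so each backward() uses a fresh graph (the `** 2` reflects the usual squared-norm ADMM term; my code above omits the square):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a tiny net and one [U, Z] pair per weight tensor.
net = nn.Linear(4, 3)
rho = 1e-3
admm_dict = {name: [torch.zeros_like(p), torch.zeros_like(p)]
             for name, p in net.named_parameters() if 'weight' in name}

def admm_penalty():
    # Rebuild the penalty from the live parameters each call; U and Z
    # carry no autograd history, so the graph is new every iteration.
    admm_loss = 0
    for name, param in net.named_parameters():
        if param.requires_grad and 'weight' in name:
            U, Z = admm_dict[name]
            admm_loss = admm_loss + rho / 2 * torch.norm(param - Z.detach() + U.detach(), 2) ** 2
    return admm_loss

for _ in range(2):  # two iterations run fine without retain_graph=True
    net.zero_grad()
    loss = admm_penalty()
    loss.backward()
```

With U and Z at zero the gradient of the penalty on the weights is still well-defined and finite, so both backward passes complete.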