I have written some simple code here where I eventually do an unconstrained optimization of a loss function l. I’m only doing a single backward on the loss each loop before I do zero_grad, so I don’t get why this fails if I don’t have retain_graph=True.
I want to scale this template up to something a lot more involved, and it would be important to get rid of the retain_graph=True so that the optimization doesn’t slow down.
torch.autograd.set_detect_anomaly(True)
class Optimizer():
def __init__(self):
#Initialize stuff
self.X, self.Y, self.Z = 1, 2, 3
self.phi1 = 1 # TRAIN PHI1
self.phi2 = 2 # TRAIN PHI2
self.mlp_xz = 3 # TRAIN MLP_XZ
self.basis = self._get_basis()
def _get_basis(self):
return 4
def _initialize_params(self):
# Initialize theta | N mlp
x = torch.rand(1, requires_grad=True)
y = torch.rand(1, requires_grad=True)
self.params = x, y
return x, y
def _get_tensor(self, constrs):
"""A helper function to get a tensor out of an array of
tensors which have gradients"""
torch_constrs = torch.zeros(len(constrs))
for i in range(len(constrs)):
torch_constrs[i] = constrs[i]
return torch_constrs
def _get_constraints(self, x, y):
# Construct the constraints
constrs = []
constrs.append(x + y - 1)
return self._get_tensor(constrs)
def _get_objective(self, x, y):
# get the objective
return y**2 + x**2
def _lagrangian(self, obj, constr, lmbda, tau, sign):
# Construct the unconstrained lagrangian from the constraints, objective and lambda values
psi = - lmbda * constr + 0.5 * tau * constr**2
psisum = torch.sum(psi)
lag = sign*obj + psisum
return(lag)
def _update_lambda(self, constr, lmbda, tau):
return lmbda - tau * constr
def optimize(self):
x, y = self._initialize_params()
obj = self._get_objective(x, y)
constr = self._get_constraints(x, y)
lmbda = torch.ones(len(constr))
tau=1
sign = 1
optimizer = optim.SGD([x, y], lr=0.005)
for i in range(10):
for i in range(30):
l = self._lagrangian(obj, constr, lmbda, tau, sign)
optimizer.zero_grad()
l.backward(retain_graph=True)
optimizer.step()
obj = self._get_objective(x, y)
constr = self._get_constraints(x, y)
lmbda = self._update_lambda(constr, lmbda, tau)
print(x, y, obj, constr)
return(x, y)
g = Optimizer()
g.optimize()
Thank you for your help!