# Why does this code need retain_graph=True despite only one .backward() in each loop?

I have written some simple code here in which I eventually do an unconstrained optimization of a loss function `l`. I'm only doing a single backward on the loss each loop, together with `zero_grad`, so I don't get why this fails if I don't have `retain_graph=True`.

I want to scale this template up to something a lot more involved, and it would be important to get rid of the `retain_graph=True` so that the optimization doesn't slow down.

```python
import torch
import torch.optim as optim

torch.autograd.set_detect_anomaly(True)

class Optimizer():

    def __init__(self):
        # Initialize stuff
        self.X, self.Y, self.Z = 1, 2, 3
        self.phi1 = 1    # TRAIN PHI1
        self.phi2 = 2    # TRAIN PHI2
        self.mlp_xz = 3  # TRAIN MLP_XZ
        self.basis = self._get_basis()

    def _get_basis(self):
        return 4

    def _initialize_params(self):
        # Initialize theta | N mlp
        # (placeholder initialization so the example runs; the real code
        #  would create the trainable parameters here)
        x = torch.tensor(0.0, requires_grad=True)
        y = torch.tensor(0.0, requires_grad=True)

        self.params = x, y

        return x, y

    def _get_tensor(self, constrs):
        """A helper function to get a tensor out of an array of constraints."""
        torch_constrs = torch.zeros(len(constrs))
        for i in range(len(constrs)):
            torch_constrs[i] = constrs[i]
        return torch_constrs

    def _get_constraints(self, x, y):
        # Construct the constraints
        constrs = []
        constrs.append(x + y - 1)
        return self._get_tensor(constrs)

    def _get_objective(self, x, y):
        # Get the objective
        return y**2 + x**2

    def _lagrangian(self, obj, constr, lmbda, tau, sign):
        # Construct the unconstrained Lagrangian from the constraints, objective and lambda values
        psi = -lmbda * constr + 0.5 * tau * constr**2
        psisum = torch.sum(psi)

        lag = sign * obj + psisum
        return lag

    def _update_lambda(self, constr, lmbda, tau):
        return lmbda - tau * constr

    def optimize(self):
        x, y = self._initialize_params()
        obj = self._get_objective(x, y)
        constr = self._get_constraints(x, y)
        lmbda = torch.ones(len(constr))
        tau = 1
        sign = 1
        optimizer = optim.SGD([x, y], lr=0.005)

        for i in range(10):          # outer loop: update the multipliers
            for j in range(30):      # inner loop: minimize the Lagrangian
                optimizer.zero_grad()
                l = self._lagrangian(obj, constr, lmbda, tau, sign)
                l.backward(retain_graph=True)
                optimizer.step()
            obj = self._get_objective(x, y)
            constr = self._get_constraints(x, y)
            lmbda = self._update_lambda(constr, lmbda, tau)
            print(x, y, obj, constr)

        return x, y

g = Optimizer()
g.optimize()
```

Looks like the problem is that part of your graph, `obj = self._get_objective(x, y)` (and likewise `constr`), is only computed once per outer iteration and then reused on every pass through the inner loop. The first `l.backward()` frees the intermediate buffers of that shared graph, so the next inner iteration builds a new `l` on top of an already-freed graph, and the backward fails unless you keep the graph alive with `retain_graph=True`. If you recompute `obj` and `constr` from the current `x`, `y` inside the inner loop (after each `optimizer.step()`), every iteration gets a fresh graph and you can drop `retain_graph=True`. The multiplier update has the same issue: `lmbda - tau * constr` drags the graph of `constr` into the next outer iteration, so it should be done with detached values.
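
The underlying autograd behaviour can be reproduced in isolation (a minimal sketch, not taken from your post): a default `backward()` frees the saved intermediate tensors of the graph it runs through, and any later backward that reaches that same graph fails.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)

# Analogue of obj/constr being built once, outside the inner loop:
obj = x ** 2

loss1 = obj + 1.0
loss1.backward()          # default retain_graph=False frees obj's saved buffers

loss2 = obj + 2.0         # a new loss, but it still hangs off the freed obj graph
try:
    loss2.backward()      # RuntimeError: trying to backward through the graph a second time
except RuntimeError as err:
    print(err)

# Rebuilding obj from the current leaf gives a fresh graph, no retain_graph needed:
obj = x ** 2
(obj + 2.0).backward()
```

Applied to your class, one way to restructure `optimize()` (a sketch, assuming the rest of the class stays as posted) is to rebuild `obj` and `constr` on every inner step and to update `lmbda` under `torch.no_grad()` so the multiplier stays out of the graph:

```python
def optimize(self):
    x, y = self._initialize_params()
    constr = self._get_constraints(x, y)      # only used to size lmbda here
    lmbda = torch.ones(len(constr))
    tau = 1
    sign = 1
    optimizer = optim.SGD([x, y], lr=0.005)

    for i in range(10):                       # outer loop: update the multipliers
        for j in range(30):                   # inner loop: minimize the Lagrangian
            optimizer.zero_grad()
            # Rebuild the graph from the current x, y on every step,
            # so each backward() runs through a fresh graph.
            obj = self._get_objective(x, y)
            constr = self._get_constraints(x, y)
            l = self._lagrangian(obj, constr, lmbda, tau, sign)
            l.backward()                      # no retain_graph needed any more
            optimizer.step()

        with torch.no_grad():                 # keep lmbda free of autograd history
            constr = self._get_constraints(x, y)
            lmbda = self._update_lambda(constr, lmbda, tau)
        print(x, y, obj, constr)

    return x, y
```

With that, each `l.backward()` only ever touches a graph built in the same iteration, so the default `retain_graph=False` is fine and the old graphs can be freed as you go.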