# Autograd in the backward pass with multiple variables

The Deep Equilibrium Model, as presented in Chapter 4: Deep Equilibrium Models, tries to learn a layer that computes the fixed point x* = f(x*) for a given function f. The forward pass can be implemented as:

```python
def forward_iteration(f, x0, max_iter=50, tol=1e-2):
    # Iterate x_{k+1} = f(x_k) until the relative residual drops below tol.
    f0 = f(x0)
    res = []
    for k in range(max_iter):
        x = f0
        f0 = f(x)
        res.append((f0 - x).norm().item() / (1e-5 + f0.norm().item()))
        if (res[-1] < tol):
            break
    return f0, res
```
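To see the iteration in action, here is a minimal sketch that runs it on a toy layer. The map `f(z) = tanh(z W^T + x)`, the matrix `W`, and the tensor shapes are assumptions chosen for illustration, not something taken from the tutorial. `W` is scaled down so that the map is very likely a contraction and the iteration settles. `forward_iteration` is repeated only so the snippet runs on its own.

```python
import torch

def forward_iteration(f, x0, max_iter=50, tol=1e-2):
    # Plain fixed-point iteration with a relative-residual stopping test.
    f0 = f(x0)
    res = []
    for k in range(max_iter):
        x = f0
        f0 = f(x)
        res.append((f0 - x).norm().item() / (1e-5 + f0.norm().item()))
        if res[-1] < tol:
            break
    return f0, res

torch.manual_seed(0)
# Hypothetical toy layer: small weights keep the map contractive.
W = 0.1 * torch.randn(4, 4)
x = torch.randn(2, 4)
f = lambda z: torch.tanh(z @ W.T + x)

z_star, res = forward_iteration(f, torch.zeros_like(x), max_iter=100, tol=1e-6)

# At a fixed point, applying f once more should leave z_star (almost) unchanged.
print("iterations:", len(res))
print("fixed-point error:", (f(z_star) - z_star).norm().item())
```

The residual list `res` is useful for checking convergence: if it does not decrease, the map is probably not contractive and plain iteration may not be a suitable solver.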

According to the tutorial (http://implicit-layers-tutorial.org/deep_equilibrium_models/), the backward pass uses implicit differentiation and is implemented as:

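```python
import torch
import torch.nn as nn
from torch import autograd

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z : self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z,x)

        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0,x)
        def backward_hook(grad):
            # solve g = J^T g + grad with the same fixed-point solver
            g, self.backward_res = self.solver(lambda y : autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g

        z.register_hook(backward_hook)
        return z
```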
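For reference, the quantity the backward hook has to produce can be written out as follows. This is a sketch of the standard implicit-differentiation argument, written with the code's names (`z` for the fixed point, `x` for the input); the symbols $J_f$, $\theta$ and $\ell$ are notation introduced here for the Jacobian, the parameters of `f`, and the loss.

```latex
z^\star = f(z^\star, x), \qquad J_f = \frac{\partial f(z^\star, x)}{\partial z^\star}

% Differentiating the fixed-point condition with respect to \theta:
\frac{\partial z^\star}{\partial \theta}
  = \left(I - J_f\right)^{-1} \frac{\partial f(z^\star, x)}{\partial \theta}

% Hence the vector g returned by the hook satisfies
g = \left(I - J_f^{\top}\right)^{-1} \nabla_{z^\star} \ell
\quad\Longleftrightarrow\quad
g = J_f^{\top} g + \nabla_{z^\star} \ell

% and the parameter gradient follows from one ordinary backward step through f:
\nabla_{\theta} \ell = \left(\frac{\partial f(z^\star, x)}{\partial \theta}\right)^{\top} g
```

The second form of the equation for $g$ is itself a fixed-point problem, which is why the same kind of solver can be reused in the backward pass.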

Now I would like to extend the above idea to a multi-variable function, i.e. (x*, y*) = f(x*, y*). The forward pass is simply updated to:

```python
def forward_iteration(f, x0, y0, max_iter=50, tol=1e-2):
    # f takes both variables and returns the updated pair
    fx0, fy0 = f(x0, y0)
    res_x, res_y = [], []
    for k in range(max_iter):
        x, y = fx0, fy0
        fx0, fy0 = f(x, y)
        # relative residual of each variable against the previous iterate
        res_x.append((fx0 - x).norm().item() / (1e-5 + fx0.norm().item()))
        res_y.append((fy0 - y).norm().item() / (1e-5 + fy0.norm().item()))
        # stop only when both variables have converged
        if max(res_x[-1], res_y[-1]) < tol:
            break
    return fx0, fy0, res_x, res_y
```
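One way to look at the two-variable problem, sketched here under assumptions, is to pack `x` and `y` into a single tensor so that (x*, y*) = f(x*, y*) becomes an ordinary single-variable fixed point. The helpers `pack`, `unpack` and `make_joint`, the toy map `f`, and the shapes below are hypothetical names and choices made for illustration. The sketch uses a plain loop rather than the solver above and assumes the first dimension is the batch dimension.

```python
import math
import torch

def pack(x, y):
    # Flatten each variable per sample and join them along the feature axis.
    return torch.cat([x.flatten(1), y.flatten(1)], dim=1)

def unpack(z, x_shape, y_shape):
    # Inverse of pack: split the joint tensor and restore the original shapes.
    nx = math.prod(x_shape[1:])
    return z[:, :nx].reshape(x_shape), z[:, nx:].reshape(y_shape)

def make_joint(f, x_shape, y_shape):
    # Wrap a two-variable map f(x, y) -> (x', y') as a one-variable map on z.
    def f_joint(z):
        x, y = unpack(z, x_shape, y_shape)
        return pack(*f(x, y))
    return f_joint

torch.manual_seed(0)
# Hypothetical coupled map; small weights keep it contractive.
A = 0.1 * torch.randn(5, 3)
B = 0.1 * torch.randn(3, 5)
bx, by = torch.randn(2, 3), torch.randn(2, 5)
f = lambda x, y: (torch.tanh(y @ A + bx), torch.tanh(x @ B + by))

x0, y0 = torch.zeros(2, 3), torch.zeros(2, 5)
f_joint = make_joint(f, x0.shape, y0.shape)

# Plain fixed-point iteration on the packed variable.
z = pack(x0, y0)
for _ in range(100):
    z = f_joint(z)

x_star, y_star = unpack(z, x0.shape, y0.shape)
fx, fy = f(x_star, y_star)
print((fx - x_star).norm().item(), (fy - y_star).norm().item())
```

Because the packed problem has a single variable, the single-variable pattern (one solver, one residual list, one hook) could in principle be applied to `f_joint` unchanged, at the cost of an extra copy for packing and unpacking.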

I am struggling to implement autograd in the backward pass. Here is my code:

```python
class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x, y):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            zx, zy, self.forward_res_x, self.forward_res_y = self.solver(lambda x, y : self.f(x, y), torch.zeros_like(x), torch.zeros_like(y), **self.kwargs)
        # evaluate f once more at the fixed point (zx, zy) so autograd records it
        zx, zy = self.f(zx, zy)

        # set up Jacobian vector product (without additional forward calls)