The Deep Equilibrium Model, as presented in Chapter 4: Deep Equilibrium Models, tries to learn a layer that computes the fixed point x* = f(x*) of a given function f. The forward pass can be implemented as:
def forward_iteration(f, x0, max_iter=50, tol=1e-2):
    f0 = f(x0)
    res = []
    for k in range(max_iter):
        x = f0
        f0 = f(x)
        res.append((f0 - x).norm().item() / (1e-5 + f0.norm().item()))
        if res[-1] < tol:
            break
    return f0, res
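For reference, the iteration can be sanity-checked on a toy contraction (the function tanh_layer and the sizes below are purely an illustrative assumption, not part of the tutorial):

import torch

# toy contractive map: small weights make tanh(W x + b) a contraction
torch.manual_seed(0)
W = 0.1 * torch.randn(3, 3)
b = torch.randn(3)
tanh_layer = lambda x: torch.tanh(W @ x + b)

x_star, res = forward_iteration(tanh_layer, torch.zeros(3))
print(res[-1])  # final relative residual, should be below tol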
According to the tutorial http://implicit-layers-tutorial.org/deep_equilibrium_models/ , the backward pass is obtained by implicit differentiation: instead of backpropagating through the solver iterations, the incoming gradient grad is mapped to g = (I - J)^(-T) grad, where J is the Jacobian of f with respect to z at the fixed point. This g is itself the solution of the fixed-point equation g = J^T g + grad, which the backward hook solves with the same solver:
import torch
import torch.nn as nn
from torch import autograd

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z: self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z, x)

        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0, x)
        def backward_hook(grad):
            g, self.backward_res = self.solver(lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g
        z.register_hook(backward_hook)
        return z
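Putting the pieces together, a minimal usage sketch might look like the following (SimpleLayer is an assumed stand-in for f; the tutorial itself uses a more elaborate convolutional f):

# assumed toy f(z, x): a single fully connected layer with a tanh nonlinearity
class SimpleLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        # shrink the weights so the fixed-point iteration is more likely to converge
        self.linear.weight.data *= 0.5

    def forward(self, z, x):
        return torch.tanh(self.linear(z) + x)

f = SimpleLayer(10)
deq = DEQFixedPoint(f, forward_iteration, max_iter=50, tol=1e-4)
x = torch.randn(2, 10)
out = deq(x)          # forward: solves z* = f(z*, x)
out.sum().backward()  # backward: solves the gradient fixed point via the hook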
Now I would like to extend the above idea to a multi-variable function, i.e. (x*, y*) = f(x*, y*). The forward pass is simply updated to be
def forward_iteration(f, x0, y0, max_iter=50, tol=1e-2):
    # f takes both variables and returns the pair of updates
    fx0, fy0 = f(x0, y0)
    res_x, res_y = [], []
    for k in range(max_iter):
        x, y = fx0, fy0
        fx0, fy0 = f(x, y)
        res_x.append((fx0 - x).norm().item() / (1e-5 + fx0.norm().item()))
        res_y.append((fy0 - y).norm().item() / (1e-5 + fy0.norm().item()))
        # stop once both residuals are below the tolerance
        if max(res_x[-1], res_y[-1]) < tol:
            break
    return fx0, fy0, res_x, res_y
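As a quick check of the signature this assumes, f takes both variables and returns both updates, and x and y may have different sizes (the coupled toy map below is purely an assumed example):

# assumed toy coupled map with x of size 3 and y of size 5
A = 0.1 * torch.randn(3, 5)
B = 0.1 * torch.randn(5, 3)
bx, by = torch.randn(3), torch.randn(5)
f_xy = lambda x, y: (torch.tanh(A @ y + bx), torch.tanh(B @ x + by))

x_star, y_star, res_x, res_y = forward_iteration(f_xy, torch.zeros(3), torch.zeros(5))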
I am struggling with implementing the autograd part of the backward pass. Here is my code:
class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x, y):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            zx, zy, self.forward_res_x, self.forward_res_y = self.solver(lambda x, y: self.f(x, y),
                                                                         torch.zeros_like(x), torch.zeros_like(y), **self.kwargs)
        zx, zy = self.f(zx, zy)

        # set up Jacobian vector product (without additional forward calls)
        zx0 = zx.clone().detach().requires_grad_()
        zy0 = zy.clone().detach().requires_grad_()
        fx0, fy0 = self.f(zx0, zy0)
        def backward_hook(grad):
            # This part is not implemented yet
            return gx, gy
        zx.register_hook(backward_hook)
        zy.register_hook(backward_hook)
        return zx, zy
Could anyone help me implement the autograd part of the backward_hook when two variables x and y are involved (given that x and y have different sizes)?