Autograd in the backward pass with multiple variables

The Deep Equilibrium Model presented in Chapter 4: Deep Equilibrium Models tries to learn a layer that computes the fixed point x* = f(x*) of a given function f. The forward pass can be implemented as:

def forward_iteration(f, x0, max_iter=50, tol=1e-2):
    f0 = f(x0)
    res = []
    for k in range(max_iter):
        x = f0
        f0 = f(x)
        # relative residual ||f(x) - x|| / ||f(x)||
        res.append((f0 - x).norm().item() / (1e-5 + f0.norm().item()))
        if (res[-1] < tol):
            break
    return f0, res
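
As a quick sanity check (a toy example of my own, not from the tutorial), the iteration converges for any contraction; here f has Lipschitz constant 0.5:

import torch

# toy contraction: the fixed-point iteration converges to the unique x* = f(x*)
f = lambda x: 0.5 * torch.tanh(x) + 1.0
x_star, res = forward_iteration(f, torch.zeros(3))
print(x_star, res[-1])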

According to the tutorial at http://implicit-layers-tutorial.org/deep_equilibrium_models/ , the backward pass uses implicit differentiation and can be implemented as:

import torch
import torch.nn as nn
from torch import autograd

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs
        
    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z : self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z,x)
        
        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0,x)
        def backward_hook(grad):
            g, self.backward_res = self.solver(lambda y : autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g
                
        z.register_hook(backward_hook)
        return z
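
For reference, here is a minimal usage sketch of the single-variable version (SimpleF is a toy update function of my own, not from the tutorial; any module taking (z, x) and returning a tensor shaped like z works the same way):

import torch
import torch.nn as nn

class SimpleF(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
    def forward(self, z, x):
        # input injection as in the tutorial: f depends on the state z and the input x
        return torch.tanh(self.lin(z) + x)

deq = DEQFixedPoint(SimpleF(16), forward_iteration, max_iter=100, tol=1e-4)
x = torch.randn(8, 16)
z = deq(x)
z.sum().backward()  # triggers backward_hook, which solves the adjoint fixed point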

Now I would like to extend the above idea to a multi-variable function, i.e. (x*, y*) = f(x*, y*). The forward pass is updated accordingly:

def forward_iteration(f, x0, y0, max_iter=50, tol=1e-2):
    fx0, fy0 = f(x0, y0)
    res_x, res_y = [], []
    for k in range(max_iter):
        x, y = fx0, fy0
        fx0, fy0 = f(x, y)
        res_x.append((fx0 - x).norm().item() / (1e-5 + fx0.norm().item()))
        res_y.append((fy0 - y).norm().item() / (1e-5 + fy0.norm().item()))
        if res_x[-1] < tol and res_y[-1] < tol:
            break
    return fx0, fy0, res_x, res_y
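
A toy check of this two-variable iteration (again my own example, just to illustrate that x and y can have different sizes, as long as f returns a pair matching their shapes):

import torch

# coupled contractions on tensors of different sizes (3 vs 5)
f = lambda x, y: (0.5 * torch.tanh(x) + 0.1 * y.mean(),
                  0.5 * torch.tanh(y) + 0.1 * x.mean())
x_star, y_star, res_x, res_y = forward_iteration(f, torch.zeros(3), torch.zeros(5))
print(res_x[-1], res_y[-1])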

I am struggling with implementing autograd in the backward pass. Here is my code:

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs
        
    def forward(self, x, y):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            zx, zy, self.forward_res_x, self.forward_res_y = self.solver(lambda zx, zy : self.f(zx, zy), torch.zeros_like(x), torch.zeros_like(y), **self.kwargs)
        zx, zy = self.f(zx, zy)
        
        # set up Jacobian vector product (without additional forward calls)
        zx0 = zx.clone().detach().requires_grad_()
        zy0 = zy.clone().detach().requires_grad_()
        fx0, fy0 = self.f(zx0, zy0)
        def backward_hook(grad):
            # This part is not implemented yet
            return gx, gy
                
        zx.register_hook(backward_hook)
        zy.register_hook(backward_hook)
        return zx, zy

Could anyone help me implement the autograd function in backward_hook when the two variables x and y are involved (given that x and y have different sizes)?