Get derivatives from your network

Hello! I have the following code (a simplified version):

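```python
import torch
import torch.nn as nn

grads = {}

def save_grad(name):
    # returns a hook that stores the gradient flowing into a tensor under `name`
    def hook(grad):
        grads[name] = grad
    return hook

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 1, bias=False)  # x -> z
        self.linear2 = nn.Linear(1, 2, bias=False)  # z -> (x_hat, y_hat)

    def forward(self, x):
        z = self.linear1(x)
        y_pred = self.linear2(z)

        return y_pred, z

# model, optimizer and data are created elsewhere (omitted in this simplified version)
for epoch in range(1000):
    model.train()
    for i, dt in enumerate(data.trn_dl):
        optimizer.zero_grad()
        output = model(dt[0])  # output = (y_pred, z)

        # derivative of x_hat (output[0][0]) w.r.t. z, captured by a hook on z
        output[1].register_hook(save_grad('z_x_hat'))
        output[0][0].backward(retain_graph=True)
        z_x_hat = grads['z_x_hat']

        # derivative of y_hat (output[0][1]) w.r.t. z, same idea
        output[1].register_hook(save_grad('z_y_hat'))
        output[0][1].backward(retain_graph=True)
        z_y_hat = grads['z_y_hat']

        z_x_hat.requires_grad = True
        z_y_hat.requires_grad = True

        # loss = | sqrt((d x_hat/dz)^2 + (d y_hat/dz)^2) - 1 |
        loss = abs(torch.sqrt(z_x_hat**2 + z_y_hat**2) - 1)
        print(z_x_hat, z_y_hat, loss)
        loss.backward()

        optimizer.step()
```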
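To see what the hook pattern above actually hands back, here is a minimal single-sample check. It repeats `SimpleNet` and `save_grad` so it runs on its own; the random input `x` of shape (1, 2) and the seed are hypothetical stand-ins for `dt[0]`, and `y_pred[0, 0]` plays the role of `output[0][0]`. It should confirm that the hooked gradient equals the weight from z to x_hat, but also that this gradient is a plain tensor with no `grad_fn` (setting `requires_grad = True` on it just turns it into a new leaf, so a loss built from it never reaches the model's weights), and that the intermediate `backward()` call has already written into the weights' `.grad`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # hypothetical seed, only for reproducibility

grads = {}

def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 1, bias=False)
        self.linear2 = nn.Linear(1, 2, bias=False)

    def forward(self, x):
        z = self.linear1(x)
        return self.linear2(z), z

model = SimpleNet()
x = torch.randn(1, 2)  # hypothetical single sample standing in for dt[0]
y_pred, z = model(x)

# Same pattern as above: hook on z, then backward from x_hat = y_pred[0, 0]
z.register_hook(save_grad('z_x_hat'))
y_pred[0, 0].backward(retain_graph=True)
z_x_hat = grads['z_x_hat']

# The captured gradient equals the weight from z to x_hat ...
print(torch.allclose(z_x_hat.flatten(), model.linear2.weight[0].flatten()))
# ... but it carries no graph, so a loss built from it cannot reach the weights
print(z_x_hat.requires_grad, z_x_hat.grad_fn)
# ... and this intermediate backward() has already filled the weights' .grad
print(model.linear2.weight.grad)
```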

So my network has two outputs (x_hat, y_hat), and I want the loss function to be computed from the derivatives of the outputs with respect to z (I need this as part of something more complex in the actual project, but this is the part where I am stuck). In this simple case, the derivatives are just the weights from z to x_hat and from z to y_hat (in the real case, each one is the product of all the partial derivatives along the path from z to the output, by the chain rule). Using this part of the code above:

```python
output[1].register_hook(save_grad('z_x_hat'))
output[0][0].backward(retain_graph=True)
z_x_hat = grads['z_x_hat']
```

I am able to get the value of this derivative (and hence the weight), but the problem is that I have to call backward() with retain_graph=True to do that, and these extra backward passes mess up my whole backprop: I want to optimize the network by backpropagating only the final loss, not these intermediate calls. Can someone help me with this? Thank you!
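One approach that may fit here (a sketch, not necessarily the only fix) is to compute the derivatives with `torch.autograd.grad(..., create_graph=True)` instead of hooks plus `backward()`. With `create_graph=True` the returned derivatives are themselves differentiable functions of the weights (for this network, d x_hat/dz and d y_hat/dz are just the two entries of `linear2.weight`), and `autograd.grad` returns the gradients rather than accumulating them into the parameters' `.grad`, so only the final `loss.backward()` touches `.grad`. Assumptions: inputs come in batches of shape (batch, 2), and x_hat and y_hat are the columns `y_pred[:, 0]` and `y_pred[:, 1]`; the random batch `x`, SGD with `lr=0.01`, the 200 steps, and averaging the loss over the batch are hypothetical choices standing in for `data.trn_dl` and the real optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # hypothetical seed, only for reproducibility

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 1, bias=False)
        self.linear2 = nn.Linear(1, 2, bias=False)

    def forward(self, x):
        z = self.linear1(x)
        y_pred = self.linear2(z)
        return y_pred, z

model = SimpleNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # hypothetical choice
x = torch.randn(8, 2)  # hypothetical batch standing in for data.trn_dl

for step in range(200):
    optimizer.zero_grad()
    y_pred, z = model(x)
    x_hat, y_hat = y_pred[:, 0], y_pred[:, 1]

    # Per-sample d x_hat/dz and d y_hat/dz, shape (batch, 1). create_graph=True
    # keeps the results differentiable w.r.t. the weights (and implies
    # retain_graph=True, so the second call can reuse the graph).
    # autograd.grad returns gradients; it does not write into any .grad.
    (z_x_hat,) = torch.autograd.grad(x_hat.sum(), z, create_graph=True)
    (z_y_hat,) = torch.autograd.grad(y_hat.sum(), z, create_graph=True)

    loss = torch.abs(torch.sqrt(z_x_hat**2 + z_y_hat**2) - 1).mean()
    loss.backward()  # the only call that fills the parameters' .grad
    optimizer.step()

w = model.linear2.weight.detach()
print(w.norm().item())  # expected to end up close to 1
```

Summing x_hat over the batch before differentiating gives per-sample derivatives only because each sample's output depends on its own z alone (true for these Linear layers, but not, for example, with batch normalization in between). For a deeper network, the same two `autograd.grad` calls apply the chain rule through all the layers between z and the outputs.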