Problems with gradient flow in a deep unfolded network

Hi everyone,
I am having some issues with gradient flow in a deep unfolded network that is an analytical variant of LISTA (ALISTA). In the recursion below,
$$x^{(k+1)} = \eta_{\beta_k}\left(x^{(k)} - \mu_k\, W^\top \big(A x^{(k)} - y\big)\right)$$

the only learnable parameters are the threshold of the soft-thresholding operator (beta in the code) and the inner step size (mu); W is fixed, obtained from an optimization problem tailored to the compressive sensing framework. My model is implemented as follows:

import cvxpy as cp
import torch
import torch.nn as nn
import torch.nn.functional as F


class ALISTA(nn.Module):
    def __init__(self, A, K=5):
        super(ALISTA, self).__init__()

        # Number of layers <-> iterations
        self.K = K

        # Parameters
        self.A = A
        self.W = self.W_optimization() 

        self.beta = nn.Parameter(torch.ones(self.K + 1, 1, 1), requires_grad=True)
        self.mu = nn.Parameter(torch.ones(self.K + 1, 1, 1) / torch.linalg.norm(self.A.T @ self.A, 2), requires_grad=True)  

        self.W1 = self.W.T @ self.A
        self.W2 = self.W.T
        
        # Losses when doing inference
        self.losses = []

    def W_optimization(self):
        # Compute the fixed analytic weight matrix W: minimize ||W^T A||_F
        # subject to w_m^T a_m = 1 for every column m (the data-free step of ALISTA)
        N, M = self.A.shape
        A_np = self.A.detach().cpu().numpy()

        W = cp.Variable((N, M))
        objective = cp.Minimize(cp.norm(W.T @ A_np, 'fro'))
        constraints = [W[:, m] @ A_np[:, m] == 1 for m in range(M)]
        prob = cp.Problem(objective, constraints)
        prob.solve()

        return torch.from_numpy(W.value).float()

    def _shrink(self, x, beta):
        # Soft-thresholding with a learnable threshold: for beta > 0,
        # beta * softshrink(x / beta, 1) == softshrink(x, beta); the rescaling is
        # needed because F.softshrink only accepts a Python float for lambd
        return beta * F.softshrink(x / beta, lambd=1)

    def forward(self, y, S=None):
        # Initial estimation with shrinkage
        x = self._shrink(self.mu[0, :, :] * y @ self.W2.t(), self.beta[0, :, :])

        for i in range(1, self.K + 1):
            h = x - self.mu[i, :, :] * (x @ self.W1.t() - y @ self.W2.t())
            x = self._shrink(h, self.beta[i, :, :])

            # If ground truth is provided, calculate the loss for monitoring
            if S is not None:
                self.losses.append(F.mse_loss(x.detach(), S.detach(), reduction="sum").item())

        return x

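For context, this is roughly how I train the model and inspect the gradients (simplified: loader, y_batch and S_batch are placeholders for my actual data pipeline, and A is the sensing matrix as a torch tensor):

model = ALISTA(A)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for y_batch, S_batch in loader:  # measurements and ground-truth sparse codes
    optimizer.zero_grad()
    x_hat = model(y_batch)
    loss = F.mse_loss(x_hat, S_batch, reduction="sum")
    loss.backward()

    # Inspect the gradients of beta and mu after backprop
    for name, p in model.named_parameters():
        grad_max = None if p.grad is None else p.grad.abs().max().item()
        print(name, grad_max)

    optimizer.step()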
The problem is that the gradients of the parameters are always 0 and I don't understand why, since the vanilla LISTA I implemented works and is designed very similarly to this one. Thanks in advance for your help.