How to unroll ISTA as a neural network

Hello,

I’m trying to unroll a proximal gradient descent/ISTA algorithm as a NN.
The classical ISTA iteration can be written as:

 x(k+1) = soft_threshold( W1 * x(k) + W2 * y , lmbd/L )

where W1 = I - H^T H / L and W2 = H^T / L are known matrices, L is the squared spectral norm of the forward operator H, and soft_threshold(z, t) = sign(z) * max(|z| - t, 0) is applied elementwise.
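
For concreteness, here is a minimal NumPy sketch of one such iteration (the names soft_threshold and ista_step are just illustrative, and H stands in for my forward operator; real-valued case only):

import numpy as np

def soft_threshold(z, t):
    # elementwise soft-thresholding: sign(z) * max(|z| - t, 0)
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def ista_step(x, y, H, lmbd):
    # one classical ISTA iteration with fixed matrices
    L  = np.linalg.norm(H, ord=2) ** 2       # Lipschitz constant of the data-fit gradient
    W1 = np.eye(H.shape[1]) - H.T @ H / L    # W1 = I - H^T H / L
    W2 = H.T / L                             # W2 = H^T / L
    return soft_threshold(W1 @ x + W2 @ y, lmbd / L)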

I want to unroll the optimization steps as a NN where W1 and W2 are learnable linear layers.

This is my attempt to do so:

import numpy as np
import torch
import torch.nn as nn


class Lista(nn.Module):

    def __init__(self, uvw, freq, npixel, n_layers, max_iter=100, verbose=False, device=None, *args, **kwargs):
        super().__init__()

        # forward operator H (nvis x npixel^2) and Lipschitz constant L = ||H||_2^2
        self.H        = torch.tensor(forward_operator(uvw, freq, npixel))
        self.n_layers = n_layers
        self.device   = device
        self.nvis, self.npixel2 = self.H.shape
        self.L        = np.linalg.norm(self.H.numpy(), ord=2) ** 2

        # one pair of learnable linear layers (Wx, Wy) per unrolled iteration
        self.layers_Wx = nn.ModuleList()
        self.layers_Wy = nn.ModuleList()
        for _ in range(self.n_layers):
            self.layers_Wx.append(nn.Linear(self.npixel2, self.npixel2, bias=False))
            self.layers_Wy.append(nn.Linear(self.nvis, self.npixel2, bias=False))  # .to(torch.cdouble)

        self._init_layer_parameters()

    def forward(self, y, lmbd, x0):
        x_   = check_tensor(x0, device=self.device)
        step = 1 / self.L
        for layer_idx in range(self.n_layers):
            # linear combination of current estimate and data, then proximal step
            z_ = self.layers_Wx[layer_idx](x_) + self.layers_Wy[layer_idx](y)
            x_ = soft_thresholding(z_, lmbd * step)
        return x_

    def _init_layer_parameters(self):
        # initialize every layer at the classical ISTA matrices:
        # Wx = I - H^H H / L,  Wy = H^H / L
        Wx_init = torch.eye(self.npixel2, dtype=self.H.dtype) - self.H.T.conj() @ self.H / self.L
        Wy_init = self.H.T.conj() / self.L

        for layer_idx in range(self.n_layers):
            # give each layer its own copy so weights are not shared across layers
            self.layers_Wx[layer_idx].weight = nn.Parameter(Wx_init.clone())
            self.layers_Wy[layer_idx].weight = nn.Parameter(Wy_init.clone())

Everything works correctly when I run the code without backpropagation (i.e., as plain ISTA, without updating W1 and W2).
However, the results deteriorate when I train it with backpropagation (no convergence, the loss explodes, etc.).
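
For context, this is roughly how I train it; the data loading, MSE loss, and optimizer settings below are only placeholders to illustrate the setup, not my actual pipeline:

model     = Lista(uvw, freq, npixel, n_layers=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for y_batch, x_true in dataloader:           # (visibilities, reference image) pairs
    x0    = torch.zeros_like(x_true)         # start each unrolled pass from zero
    x_hat = model(y_batch, lmbd=0.1, x0=x0)
    loss  = torch.nn.functional.mse_loss(x_hat, x_true)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()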

I would greatly appreciate any insights or guidance on this issue.
Thanks