Hello,
I’m trying to unroll a proximal gradient descent / ISTA algorithm as a neural network.
The classical ISTA iteration can be written as:
x(k+1) = soft_threshold( W1 * x(k) + W2 * y, lmbd / L )
where W1 = I - H^H H / L and W2 = H^H / L are known matrices built from the forward operator H, with L = ||H||_2^2 (the soft-threshold already carries the sign factor).
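For reference, this is the plain (non-learned) ISTA step I am unrolling; a minimal sketch assuming a real-valued signal, a generic forward operator H, and the usual soft-threshold:

```python
import torch

def soft_threshold(z, thresh):
    # sign(z) * max(|z| - thresh, 0)
    return torch.sign(z) * torch.clamp(torch.abs(z) - thresh, min=0.0)

def ista_step(x, y, H, lmbd, L):
    # x(k+1) = soft_threshold(W1 @ x + W2 @ y, lmbd / L)
    # with W1 = I - H^T H / L and W2 = H^T / L
    grad = H.T @ (H @ x - y)  # gradient of the data-fidelity term
    return soft_threshold(x - grad / L, lmbd / L)
```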
I want to unroll the optimization steps as a NN where W1 and W2 are learnable linear layers.
This is my attempt to do so:
```python
import numpy as np
import torch
import torch.nn as nn

# forward_operator, check_tensor and soft_thresholding are my own helpers.

class Lista(torch.nn.Module):
    def __init__(self, uvw, freq, npixel, n_layers, max_iter=100,
                 verbose=False, device=None, *args, **kwargs):
        super().__init__()
        self.H = torch.tensor(forward_operator(uvw, freq, npixel))
        self.n_layers = n_layers
        self.device = device
        self.nvis, self.npixel2 = self.H.shape
        # Lipschitz constant of the gradient: squared spectral norm of H
        self.L = np.linalg.norm(self.H, ord=2)**2
        self.layers_Wx = nn.ModuleList()
        self.layers_Wy = nn.ModuleList()
        for _ in range(self.n_layers):
            self.layers_Wx.append(torch.nn.Linear(self.npixel2, self.npixel2, bias=False))
            self.layers_Wy.append(torch.nn.Linear(self.nvis, self.npixel2, bias=False))  # .to(torch.cdouble)
        self._init_layer_parameters()

    def forward(self, y, lmbd, x0):
        x_ = check_tensor(x0, device=self.device)
        step = 1 / self.L
        for layer_idx in range(self.n_layers):
            # z = W1 x + W2 y, then soft-threshold with lmbd / L
            z_ = self.layers_Wx[layer_idx](x_) + self.layers_Wy[layer_idx](y)
            x_ = soft_thresholding(z_, lmbd * step)
        return x_

    def _init_layer_parameters(self):
        # ISTA initialization: W1 = I - H^H H / L, W2 = H^H / L
        Wx_init = torch.from_numpy(np.eye(self.npixel2) - np.matmul(self.H.T.conj(), self.H) / self.L)
        Wy_init = torch.from_numpy(self.H.T.conj() / self.L)
        for layer_idx in range(self.n_layers):
            self.layers_Wx[layer_idx].weight = nn.Parameter(Wx_init)
            self.layers_Wy[layer_idx].weight = nn.Parameter(Wy_init)
```
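For completeness, this is roughly how I run the training; the dataset, optimizer settings and the MSE loss are just placeholders for my actual setup:

```python
# Hypothetical training loop: uvw, freq, npixel, lmbd and the (y, x_true)
# pairs stand in for my real data; only the structure matters here.
model = Lista(uvw, freq, npixel, n_layers=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()

for epoch in range(n_epochs):
    for y, x_true in train_loader:
        x0 = torch.zeros_like(x_true)   # start the unrolled iterations from zero
        x_hat = model(y, lmbd, x0)
        loss = loss_fn(x_hat, x_true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```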
Everything works correctly when I run the code without backpropagation (i.e., as plain ISTA, without updating W1 and W2).
However, the results deteriorate as soon as I train it with backpropagation: the iterations no longer converge and the loss explodes.
I would greatly appreciate any insights or guidance on this issue.
Thanks