FISTA Optimizer Implementation for Neural Networks with Sparse Regularization

I’m implementing a FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) optimizer in PyTorch for training neural networks with sparse regularization. My implementation doesn’t seem to be working as expected, and I’m seeking advice on how to properly handle different sets of parameters.

I’ve created a custom FISTA optimizer that follows the standard backtracking line search approach:

```python
from math import sqrt

import torch
from torch.optim import Optimizer


class FISTA(Optimizer):
    """FISTA optimizer with backtracking line search (one iteration per step call).

    Solves: min_x F(x)=f(x)+g(x), where
      - f: smooth term (provides gradient via closure)
      - g: non-smooth term via prox_func

    Each step() call performs exactly one outer FISTA iteration including line search.
    """
    def __init__(
        self,
        params,
        prox_func,
        lr=0.1,
        lr_decay=0.5,
        max_line_search=20,
        use_acceleration=True
    ):
        defaults = dict(
            prox_func=prox_func,
            lr=lr,
            lr_decay=lr_decay,
            max_line_search=max_line_search,
            use_acceleration=use_acceleration
        )
        super().__init__(params, defaults)

        # initialize state per parameter
        for group in self.param_groups:
            for p in group['params']:
                st = self.state[p]
                st['x_prev'] = p.data.clone()
                st['y'] = p.data.clone()
                st['tk'] = 1.0
                st['grad'] = torch.zeros_like(p.data)
                st['lr'] = group['lr']

    def step(self, closure):
        """Perform one FISTA iteration (with backtracking)"""
        if closure is None:
            raise ValueError("FISTA requires a closure that returns loss and calls backward().")

        # evaluate f at momentum point and get gradients
        for group in self.param_groups:
            for p in group['params']:
                p.data.copy_(self.state[p]['y'])

        fy = closure(backward=True)

        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['grad'].copy_(p.grad)

        for group in self.param_groups:
            prox = group['prox_func']
            lr_decay = group['lr_decay']
            max_ls = group['max_line_search']
            use_acc = group['use_acceleration']

            # single outer iteration over all params
            for p in group['params']:
                grad = self.state[p]['grad']
                x_prev = self.state[p]['x_prev']
                y = self.state[p]['y']
                tk = self.state[p]['tk']
                lr = self.state[p]['lr']

                # backtracking line search for this param
                current_lr = lr
                ls = 0
                while True:
                    # gradient step and prox
                    v = y - current_lr * grad
                    ply = prox.apply(v, current_lr)

                    # evaluate f at prox point
                    with torch.no_grad():
                        p.data.copy_(ply)
                    fply = closure(backward=False)

                    # Q(beta, y) = f(y) + <beta - y, ∇f(y)> + (1/(2*lr)) * ||beta - y||_2^2 + g(beta)
                    # g(beta) appears on both sides of the test, so only the smooth part is compared
                    diff = ply - y
                    Q2 = torch.dot(grad.view(-1), diff.view(-1))
                    Q3 = (1/(2*current_lr)) * torch.dot(diff.view(-1), diff.view(-1))
                    Q = fy + Q2 + Q3

                    if fply <= Q:
                        break
                    elif ls >= max_ls:
                        with torch.no_grad():
                            p.data.copy_(y)
                        break

                    current_lr *= lr_decay
                    ls += 1

                self.state[p]['lr'] = current_lr

                # End of line search
                # -------------------------------------------------

                # -------------------------------------------------
                # FISTA update step
                if use_acc:
                    tkp = 0.5 * (1.0 + sqrt(1.0 + 4.0 * tk * tk))
                    momentum = p.data + ((tk - 1.0) / tkp) * (p.data - x_prev)
                    self.state[p]['x_prev'].copy_(p.data)
                    self.state[p]['y'] = momentum
                    self.state[p]['tk'] = tkp
                else:
                    self.state[p]['x_prev'].copy_(p.data)
                    self.state[p]['y'] = p.data.clone()
                    self.state[p]['tk'] = 1.0
```
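For reference, the iteration this class aims to implement is standard FISTA with backtracking, written here in the code's terms: γ_k is the step size `lr`, η is `lr_decay`, g is the term handled by `prox_func`, and x, y stand for all optimized parameters taken together, so the inner product and norm in the acceptance test run over every tensor at once.

```latex
\begin{aligned}
x_k &= \operatorname{prox}_{\gamma_k g}\!\bigl(y_k - \gamma_k \nabla f(y_k)\bigr),\\
&\text{accept } \gamma_k \text{ if } f(x_k) \le f(y_k) + \langle \nabla f(y_k),\, x_k - y_k \rangle + \tfrac{1}{2\gamma_k}\lVert x_k - y_k \rVert_2^2, \text{ else set } \gamma_k \leftarrow \eta\,\gamma_k \text{ and retry},\\
t_{k+1} &= \tfrac{1}{2}\Bigl(1 + \sqrt{1 + 4t_k^2}\Bigr),\\
y_{k+1} &= x_k + \frac{t_k - 1}{t_{k+1}}\,(x_k - x_{k-1}).
\end{aligned}
```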

I’m trying to train a neural network with the following objective function:

Loss = MSE(y_pred, y_true) + λ * ||θ||₁

Where:

  • MSE is the mean squared error
  • λ is the regularization coefficient
  • θ is a subset of the network parameters (specifically the first layer weights and all biases)
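To make the split concrete, here is a minimal sketch of how this objective could map onto the optimizer above, assuming a hypothetical two-layer `nn.Sequential` network (the real architecture isn't shown). Following the optimizer's docstring, the closure returns only the smooth MSE term f, and λ·||θ||₁ is left entirely to the proximal operator. `split_params`, `make_closure`, and `full_objective` are illustrative names, and the `closure(backward=...)` signature is inferred from how `step()` calls it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical architecture (assumption): the post does not show the real network.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))


def split_params(model):
    """Return (theta, rest): theta = first-layer weight + every bias, rest = all others."""
    theta, rest = [], []
    for name, p in model.named_parameters():
        # nn.Sequential names parameters "<index>.weight" / "<index>.bias".
        if name == "0.weight" or name.endswith(".bias"):
            theta.append(p)
        else:
            rest.append(p)
    return theta, rest


def make_closure(model, x, y_true):
    """Closure for the smooth part f = MSE only; the L1 term is left to the prox."""
    def closure(backward=True):
        if backward:
            model.zero_grad()
            loss = F.mse_loss(model(x), y_true)
            loss.backward()
        else:
            with torch.no_grad():  # line-search evaluations need no graph
                loss = F.mse_loss(model(x), y_true)
        return loss.item()
    return closure


def full_objective(model, x, y_true, lam):
    """Full objective MSE + lam * ||theta||_1, for monitoring progress only."""
    theta, _ = split_params(model)
    with torch.no_grad():
        mse = F.mse_loss(model(x), y_true)
        l1 = sum(p.abs().sum() for p in theta)
    return (mse + lam * l1).item()
```

`split_params` also yields the two lists that would go into separate parameter groups (regularized and unregularized).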

From my testing, the optimizer doesn’t seem to be working as expected. I have a few specific concerns:

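  1. The current implementation performs a separate line search for each parameter tensor, which seems inefficient and potentially problematic: the closure always evaluates the whole network, so when one tensor is tested, the tensors processed before it have already moved to their prox points, while fy was computed with every tensor at its momentum point y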
  2. Parameters are updated immediately after their individual line search instead of all together
  3. I’m not sure if the approach scales well for neural networks with many parameters

Specific Questions:

  1. Parameter Grouping: Should I flatten all regularized parameters into a single tensor for FISTA optimization? Or keep them separate?

  2. Mixed Regularization: What’s the best way to handle the fact that I want to apply L1 regularization only to specific parameters (first layer weights and all biases) but not others?

  3. Line Search Efficiency: The current implementation evaluates the model at least once per parameter tensor during line search, plus once more for every backtracking step, which could be very inefficient for large networks. How can I optimize this?

  4. Global vs. Local Learning Rate: Should I use a single learning rate for all parameters, or allow parameter-specific rates?

  5. Convergence Issues: Are there any known issues with FISTA convergence in neural network training that might be affecting my implementation?
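For comparison, here is a hedged sketch of one *global* FISTA iteration, assuming the smooth loss is available as a function `f()` that reads the current parameter values. All tensors share one step size and one line search; unregularized tensors get an identity prox through their own group; and the inner product and norm in the sufficient-decrease test are summed over every tensor, which is equivalent to flattening them into one vector without copying. `FistaSketch` and `soft_threshold` are hypothetical names, the sketch uses `torch.autograd.grad` instead of the post's `closure(backward=...)` convention, and it is not a drop-in replacement for the `Optimizer` subclass.

```python
import math
import torch


def soft_threshold(v, thr):
    """prox of thr * ||.||_1 (same formula as the post's L1Prox.apply)."""
    return torch.sign(v) * torch.clamp(v.abs() - thr, min=0)


class FistaSketch:
    """Hypothetical global FISTA: one shared line search and step size per step()."""

    def __init__(self, groups, lr=1.0, lr_decay=0.5, max_ls=30):
        # groups: list of (tensors, prox) pairs; prox(v, step) returns the prox point.
        self.params = [p for ps, _ in groups for p in ps]
        self.proxes = [prox for ps, prox in groups for _ in ps]
        self.x_prev = [p.detach().clone() for p in self.params]
        self.y = [p.detach().clone() for p in self.params]
        self.t, self.lr, self.lr_decay, self.max_ls = 1.0, lr, lr_decay, max_ls

    def _load(self, values):
        with torch.no_grad():
            for p, v in zip(self.params, values):
                p.copy_(v)

    def step(self, f):
        """f() returns the smooth loss only (no L1 term) as a differentiable tensor."""
        self._load(self.y)
        fy = f()
        grads = torch.autograd.grad(fy, self.params)  # one backward pass for all tensors
        fy = fy.detach()
        lr = self.lr
        for _ in range(self.max_ls):
            # Candidate for every tensor at once; each tensor uses its group's prox.
            x_new = [prox(y - lr * g, lr)
                     for prox, y, g in zip(self.proxes, self.y, grads)]
            self._load(x_new)
            with torch.no_grad():
                fx = f()  # one forward pass per trial, regardless of tensor count
            diffs = [xn - y for xn, y in zip(x_new, self.y)]
            # Quadratic upper model summed over all tensors (same as flattening them).
            q = (fy + sum((g * d).sum() for g, d in zip(grads, diffs))
                 + sum((d * d).sum() for d in diffs) / (2 * lr))
            if fx <= q:
                break
            lr *= self.lr_decay
        # If no trial passed within max_ls, the last candidate (smallest step) is kept.
        self.lr = lr
        t_next = 0.5 * (1.0 + math.sqrt(1.0 + 4.0 * self.t ** 2))
        beta = (self.t - 1.0) / t_next
        self.y = [xn + beta * (xn - xp) for xn, xp in zip(x_new, self.x_prev)]
        self.x_prev = [xn.clone() for xn in x_new]
        self.t = t_next
        return fx.item()
```

Each backtracking trial costs one forward pass of the whole model and the gradient is computed once per iteration, so the number of loss evaluations no longer grows with the number of parameter tensors.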

For reference, here’s my proximal operator for L1 regularization:

```python
import torch


class L1Prox:
    def __init__(self, lambda_):
        self.lambda_ = lambda_

    def apply(self, x, gamma):
        """Proximal operator for L1 norm: prox_{γλ||·||₁}(x) = soft_threshold(x, γλ)"""
        threshold = gamma * self.lambda_
        return torch.sign(x) * torch.clamp(torch.abs(x) - threshold, min=0)
```
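A quick check of this operator on a few values, plus a comparison with PyTorch's built-in `torch.nn.functional.softshrink`, which computes the same soft-thresholding. `IdentityProx` is a hypothetical helper for parameters with no penalty (g = 0), where the prox step reduces to a plain gradient step.

```python
import torch
import torch.nn.functional as F


class L1Prox:
    """The post's L1 proximal operator (unchanged): soft-thresholding at gamma * lambda_."""
    def __init__(self, lambda_):
        self.lambda_ = lambda_

    def apply(self, x, gamma):
        threshold = gamma * self.lambda_
        return torch.sign(x) * torch.clamp(torch.abs(x) - threshold, min=0)


class IdentityProx:
    """Hypothetical prox for unregularized tensors (g = 0): returns its input."""
    def apply(self, x, gamma):
        return x


x = torch.tensor([-2.0, -0.3, 0.0, 0.4, 1.5])
out = L1Prox(lambda_=0.5).apply(x, gamma=1.0)  # threshold = 1.0 * 0.5 = 0.5
print(out)                                       # values -1.5, 0, 0, 0, 1.0
print(torch.allclose(out, F.softshrink(x, 0.5)))  # same operator as the built-in
```

Because `torch.optim.Optimizer` lets each param-group dict override the defaults, a prox like this could be assigned per group with the class above, e.g. `FISTA([{'params': theta, 'prox_func': L1Prox(lam)}, {'params': rest, 'prox_func': IdentityProx()}], prox_func=L1Prox(lam))`.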

Any insights, code improvements, or alternative approaches would be greatly appreciated!