I’m implementing a FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) optimizer in PyTorch for training neural networks with sparse regularization. My implementation doesn’t seem to be working as expected, and I’m seeking advice on how to properly handle different sets of parameters.
I’ve created a custom FISTA optimizer that follows the standard backtracking line search approach:
from math import sqrt

import torch
from torch.optim import Optimizer


class FISTA(Optimizer):
    """FISTA optimizer with backtracking line search (one iteration per step call).

    Solves: min_x F(x) = f(x) + g(x), where
    - f: smooth term (provides gradient via closure)
    - g: non-smooth term via prox_func

    Each step() call performs exactly one outer FISTA iteration including line search.
    """

    def __init__(
        self,
        params,
        prox_func,
        lr=0.1,
        lr_decay=0.5,
        max_line_search=20,
        use_acceleration=True
    ):
        defaults = dict(
            prox_func=prox_func,
            lr=lr,
            lr_decay=lr_decay,
            max_line_search=max_line_search,
            use_acceleration=use_acceleration
        )
        super().__init__(params, defaults)
        # initialize state per parameter
        for group in self.param_groups:
            for p in group['params']:
                st = self.state[p]
                st['x_prev'] = p.data.clone()
                st['y'] = p.data.clone()
                st['tk'] = 1.0
                st['grad'] = torch.zeros_like(p.data)
                st['lr'] = group['lr']

    def step(self, closure=None):
        """Perform one FISTA iteration (with backtracking)"""
        if closure is None:
            raise ValueError("FISTA requires a closure that returns loss and calls backward().")
        # evaluate f at momentum point and get gradients
        for group in self.param_groups:
            for p in group['params']:
                p.data.copy_(self.state[p]['y'])
        # closure(backward=True) must return f and populate p.grad
        fy = closure(backward=True)
        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['grad'].copy_(p.grad)
        for group in self.param_groups:
            prox = group['prox_func']
            lr_decay = group['lr_decay']
            max_ls = group['max_line_search']
            use_acc = group['use_acceleration']
            # single outer iteration over all params
            for p in group['params']:
                grad = self.state[p]['grad']
                x_prev = self.state[p]['x_prev']
                y = self.state[p]['y']
                tk = self.state[p]['tk']
                lr = self.state[p]['lr']
                # backtracking line search for this param
                current_lr = lr
                ls = 0
                while True:
                    # gradient step and prox
                    v = y - current_lr * grad
                    ply = prox.apply(v, current_lr)
                    # evaluate f at prox point
                    with torch.no_grad():
                        p.data.copy_(ply)
                        fply = closure(backward=False)
                    # Q(beta, y) = f(y) + <beta-y, ∇f(y)> + (1/(2*lr))||beta-y||_2^2 + g(beta)
                    # g(beta) appears on both sides of F(beta) <= Q(beta, y), so it is dropped here
                    diff = ply - y
                    Q2 = torch.dot(grad.view(-1), diff.view(-1))
                    Q3 = (1 / (2 * current_lr)) * torch.dot(diff.view(-1), diff.view(-1))
                    Q = fy + Q2 + Q3
                    if fply <= Q:
                        break
                    elif ls >= max_ls:
                        # line search failed: fall back to the momentum point
                        with torch.no_grad():
                            p.data.copy_(y)
                        break
                    current_lr *= lr_decay
                    ls += 1
                self.state[p]['lr'] = current_lr
                # End of line search
                # -------------------------------------------------
                # FISTA update step
                if use_acc:
                    tkp = 0.5 * (1.0 + sqrt(1.0 + 4.0 * tk * tk))
                    momentum = p.data + ((tk - 1.0) / tkp) * (p.data - x_prev)
                    self.state[p]['x_prev'].copy_(p.data)
                    self.state[p]['y'] = momentum
                    self.state[p]['tk'] = tkp
                else:
                    self.state[p]['x_prev'].copy_(p.data)
                    self.state[p]['y'] = p.data.clone()
                    self.state[p]['tk'] = 1.0
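To make the acceleration branch easier to reason about, here is a small standalone sketch of the momentum schedule it uses, t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2 with extrapolation weight (t_k - 1) / t_{k+1}. The helper name `fista_momentum_coefficients` is hypothetical and only for illustration; it uses the standard library alone.

```python
from math import sqrt


def fista_momentum_coefficients(num_iters):
    """Return the extrapolation weights (t_k - 1) / t_{k+1} for k = 1..num_iters."""
    t = 1.0  # t_1 = 1, matching the initial state['tk'] in the optimizer
    coeffs = []
    for _ in range(num_iters):
        t_next = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t * t))
        coeffs.append((t - 1.0) / t_next)
        t = t_next
    return coeffs


coeffs = fista_momentum_coefficients(5)
for k, c in enumerate(coeffs, start=1):
    print(f"iteration {k}: extrapolation weight = {c:.4f}")
```

The first weight is zero, so the first iteration is a plain proximal gradient step, and the weights then grow toward 1. Because `tk` is stored per parameter but always advanced in lockstep, every parameter sees the same weight at a given iteration.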
I’m trying to train a neural network with the following objective function:
Loss = MSE(y_pred, y_true) + λ * ||θ||₁
Where:
- MSE is the mean squared error
- λ is the regularization coefficient
- θ is a subset of the network parameters (specifically the first layer weights and all biases)
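To make the objective concrete, the following is a minimal sketch of how θ could be separated from the remaining parameters and how the two terms could be evaluated. The network `net`, its layer sizes, and the helper names `smooth_loss` and `l1_penalty` are assumptions made up for this illustration, not part of the actual model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer network, for illustration only.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# theta: first-layer weights plus every bias in the network.
theta = [net[0].weight] + [
    p for name, p in net.named_parameters() if name.endswith("bias")
]
theta_ids = {id(p) for p in theta}
# Everything else is left unregularized.
unregularized = [p for p in net.parameters() if id(p) not in theta_ids]


def smooth_loss(x, y_true):
    """Smooth term f: the MSE of the network output."""
    return nn.functional.mse_loss(net(x), y_true)


def l1_penalty(lam):
    """Non-smooth term g: lam * ||theta||_1, summed over the tensors in theta."""
    return lam * sum(p.abs().sum() for p in theta)


x, y_true = torch.randn(16, 4), torch.randn(16, 1)
lam = 1e-2
total = smooth_loss(x, y_true) + l1_penalty(lam)
print(len(theta), len(unregularized), total.item())
```

With this split, only `smooth_loss` would go through the closure and autograd, while the L1 term would be handled by the proximal operator on the tensors in `theta`.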
From my testing, the optimizer doesn’t seem to be working as expected. I have a few specific concerns:
- The current implementation performs line search separately for each parameter, which seems inefficient and potentially problematic
- Parameters are updated immediately after their individual line search instead of all together
- I’m not sure if the approach scales well for neural networks with many parameters
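For comparison with the per-parameter loop, here is a rough sketch of what a single line search shared by all parameters might look like. It is deliberately simplified and makes several assumptions: it performs a plain proximal gradient (ISTA) step without the momentum extrapolation, it applies the same L1 weight `lam` to every tensor, and it is exercised on a toy least-squares problem rather than a network. The names `joint_backtracking_step`, `soft_threshold`, and `objective` are hypothetical.

```python
import torch


def soft_threshold(x, threshold):
    return torch.sign(x) * torch.clamp(x.abs() - threshold, min=0)


def joint_backtracking_step(params, closure, lam, lr, lr_decay=0.5, max_ls=20):
    """One proximal gradient step with one line search shared by all params.

    closure(backward) returns f at the current values and, when backward is
    True, also populates p.grad. Returns (step size used, success flag).
    """
    for p in params:
        p.grad = None
    fy = closure(True).item()
    ys = [p.detach().clone() for p in params]
    grads = [p.grad.detach().clone() for p in params]
    for _ in range(max_ls + 1):
        with torch.no_grad():
            # Move every parameter to its candidate point, then evaluate f once.
            for p, y, g in zip(params, ys, grads):
                p.copy_(soft_threshold(y - lr * g, lr * lam))
            f_new = closure(False).item()
            # Linear and quadratic terms of the upper bound, summed over all params.
            lin = sum(torch.sum(g * (p - y)).item() for p, y, g in zip(params, ys, grads))
            sq = sum(torch.sum((p - y) ** 2).item() for p, y in zip(params, ys))
        if f_new <= fy + lin + sq / (2.0 * lr):
            return lr, True
        lr *= lr_decay
    # Line search failed: restore all parameters together.
    with torch.no_grad():
        for p, y in zip(params, ys):
            p.copy_(y)
    return lr, False


# Toy problem: f(w, c) = mean((A w + c - b)^2), in float64 to limit rounding noise.
torch.manual_seed(0)
A = torch.randn(20, 5, dtype=torch.float64)
b = torch.randn(20, dtype=torch.float64)
w = torch.zeros(5, dtype=torch.float64, requires_grad=True)
c = torch.zeros(1, dtype=torch.float64, requires_grad=True)


def closure(backward):
    loss = torch.mean((A @ w + c - b) ** 2)
    if backward:
        loss.backward()
    return loss


def objective(lam):
    with torch.no_grad():
        return (closure(False) + lam * (w.abs().sum() + c.abs().sum())).item()


lam, lr = 0.01, 1.0
history = [objective(lam)]
for _ in range(50):
    lr, ok = joint_backtracking_step([w, c], closure, lam, lr)
    history.append(objective(lam))
print(history[0], history[-1])
```

In this form the model is evaluated once per backtracking trial regardless of how many parameter tensors there are, and the sufficient-decrease test compares f at a point where all parameters have moved consistently.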
Specific Questions:
- Parameter Grouping: Should I flatten all regularized parameters into a single tensor for FISTA optimization? Or keep them separate?
- Mixed Regularization: What’s the best way to handle the fact that I want to apply L1 regularization only to specific parameters (first layer weights and all biases) but not others?
- Line Search Efficiency: The current implementation evaluates the model once per parameter during line search, which could be very inefficient for large networks. How can I optimize this?
- Global vs. Local Learning Rate: Should I use a single learning rate for all parameters, or allow parameter-specific rates?
- Convergence Issues: Are there any known issues with FISTA convergence in neural network training that might be affecting my implementation?
For reference, here’s my proximal operator for L1 regularization:
class L1Prox:
    def __init__(self, lambda_):
        self.lambda_ = lambda_

    def apply(self, x, gamma):
        """Proximal operator for L1 norm: prox_{γλ||·||₁}(x) = soft_threshold(x, γλ)"""
        threshold = gamma * self.lambda_
        return torch.sign(x) * torch.clamp(torch.abs(x) - threshold, min=0)
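As a quick sanity check of the operator, the sketch below repeats the class so that it runs on its own, applies it to a few hand-picked values, and compares the result against a brute-force minimization of 0.5 * (u - x)^2 + γλ|u| over a grid. The input values and the grid resolution are arbitrary choices for illustration.

```python
import torch


class L1Prox:
    def __init__(self, lambda_):
        self.lambda_ = lambda_

    def apply(self, x, gamma):
        threshold = gamma * self.lambda_
        return torch.sign(x) * torch.clamp(torch.abs(x) - threshold, min=0)


prox = L1Prox(lambda_=1.0)
x = torch.tensor([-2.0, -0.3, 0.0, 0.3, 2.0])
gamma = 0.5  # threshold = gamma * lambda_ = 0.5
out = prox.apply(x, gamma)

# Brute force: for each x_i, minimize 0.5 * (u - x_i)^2 + threshold * |u| on a grid.
u = torch.linspace(-3.0, 3.0, 6001)
cost = 0.5 * (u[None, :] - x[:, None]) ** 2 + gamma * prox.lambda_ * u.abs()[None, :]
brute = u[cost.argmin(dim=1)]

print(out)
print(brute)
```

Entries whose magnitude is at most the threshold are set to zero, and the remaining entries are shrunk toward zero by the threshold, which agrees with the grid search up to the grid spacing.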
Any insights, code improvements, or alternative approaches would be greatly appreciated!