Hello,
I’ve developed a custom FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) optimizer in PyTorch for a project I’m working on. The optimizer works perfectly under normal circumstances. However, I’ve run into a problem when trying to roll back the model parameters to a previous state during training. Here is the optimizer:
```python
import torch
import torch.nn as nn


class FISTA(torch.optim.Optimizer):
    def __init__(self, params, lr, lambda_):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if lambda_ < 0.0:
            raise ValueError(f"Invalid lambda: {lambda_} - should be >= 0.0")
        defaults = dict(lr=lr, lambda_=lambda_)
        super().__init__(params, defaults)

    def shrinkage_operator(self, u, thresh):
        # Soft-thresholding: sign(u) * max(|u| - thresh, 0)
        return torch.sign(u) * torch.maximum(torch.abs(u) - thresh, torch.tensor(0.0, device=u.device))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lambda_ = group['lambda_']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # State initialization (first step for this parameter)
                if len(state) == 0:
                    state['x_prev'] = p.clone()
                    state['y_prev'] = p.clone()
                    state['t_prev'] = torch.tensor(1., device=p.device)

                x_prev, y_prev, t_prev = state['x_prev'], state['y_prev'], state['t_prev']
                # Proximal gradient step, taken from the stored y_prev (not from p)
                x_next = self.shrinkage_operator(y_prev - lr * grad, lambda_)
                t_next = (1. + torch.sqrt(1. + 4. * t_prev ** 2)) / 2.
                y_next = x_next + ((t_prev - 1) / t_next) * (x_next - x_prev)
                state['x_prev'], state['y_prev'], state['t_prev'] = x_next, y_next, t_next
                p.copy_(x_next)

        return loss
```
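For reference, per parameter tensor the step() above computes the following, writing η for lr, λ for lambda_, and g_k for p.grad (the gradient evaluated at whatever p currently holds, which after each step is x_{k-1}), with x_0 = y_0 = the initial value of p and t_0 = 1:

$$
\begin{aligned}
x_k &= S_\lambda\!\left(y_{k-1} - \eta\, g_k\right), \qquad S_\lambda(u) = \operatorname{sign}(u)\,\max\left(|u| - \lambda,\ 0\right),\\
t_k &= \frac{1 + \sqrt{1 + 4\, t_{k-1}^2}}{2},\\
y_k &= x_k + \frac{t_{k-1} - 1}{t_k}\left(x_k - x_{k-1}\right),
\end{aligned}
$$

after which p is overwritten with x_k. Read this way, the new iterate is built from y_{k-1} and x_{k-1} stored in self.state[p]; the value written into p only enters through g_k. As a side observation, the standard FISTA formulation evaluates the gradient at y_{k-1} and thresholds at ηλ rather than λ, so the code above appears to differ from it on both points.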
I then use this optimizer in my training loop:
```python
optimizer_penalized = FISTA(params=layer1.parameters(), lambda_=lambda_, lr=lr)
optimizer_unpenalized = FISTA(params=layer2.parameters(), lambda_=0.0, lr=lr)
```
I need two of them because I only want the first layer of my model to be shrunk. This optimizer works perfectly fine until I modify layer1 and layer2 myself. Let’s say I train my model and, at every epoch that produces a lower cost value, I save the layers like so:
```python
layer1_before_dict = (layer1.weight.data.clone().detach(), layer1.bias.data.clone().detach())
layer2_before_dict = (layer2.weight.data.clone().detach(), layer2.bias.data.clone().detach())
```
and when an epoch increases the cost, I call
```python
with torch.no_grad():
    layer1.weight.copy_(layer1_before_dict[0])
    layer1.bias.copy_(layer1_before_dict[1])
    layer2.weight.copy_(layer2_before_dict[0])
    layer2.bias.copy_(layer2_before_dict[1])
```
The issue arises here: after I roll back the layer parameters, the optimizer’s internal state (x_prev, y_prev, t_prev) is not rolled back with them. As a result, the model doesn’t actually return to its previous state as intended.
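A minimal sketch of what seems to be happening, and one possible way around it. It assumes a compact restatement of the FISTA class above (closure support dropped) and a hypothetical toy problem: a single 3-element parameter w fitted by least squares to a hypothetical target, with a hypothetical train_step helper standing in for one epoch. Restoring only w leaves x_prev, y_prev and t_prev untouched, so the next step starts from the stale y_prev. Saving a deep copy of opt.state_dict() next to the weights and restoring both makes the step after the rollback reproduce the original one.

```python
import copy

import torch


class FISTA(torch.optim.Optimizer):
    # Compact restatement of the FISTA optimizer from the question (no closure).
    def __init__(self, params, lr, lambda_):
        super().__init__(params, dict(lr=lr, lambda_=lambda_))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["x_prev"] = p.clone()
                    state["y_prev"] = p.clone()
                    state["t_prev"] = torch.tensor(1.0, device=p.device)
                x_prev, y_prev, t_prev = state["x_prev"], state["y_prev"], state["t_prev"]
                u = y_prev - group["lr"] * p.grad
                # Soft-thresholding (shrinkage) at lambda_
                x_next = torch.sign(u) * torch.clamp(u.abs() - group["lambda_"], min=0.0)
                t_next = (1.0 + torch.sqrt(1.0 + 4.0 * t_prev ** 2)) / 2.0
                y_next = x_next + ((t_prev - 1.0) / t_next) * (x_next - x_prev)
                state["x_prev"], state["y_prev"], state["t_prev"] = x_next, y_next, t_next
                p.copy_(x_next)


def train_step(w, opt, target):
    # Hypothetical "epoch": least-squares loss, backward pass, one FISTA step.
    opt.zero_grad()
    loss = ((w - target) ** 2).sum()
    loss.backward()
    opt.step()


target = torch.tensor([1.0, -2.0, 0.05])  # hypothetical data
w = torch.nn.Parameter(torch.zeros(3))
opt = FISTA([w], lr=0.1, lambda_=0.01)
for _ in range(3):
    train_step(w, opt, target)

# Snapshot the weights AND the optimizer state. state_dict() holds references to
# the live state, so deepcopy it to keep later steps from modifying the snapshot.
w_saved = w.detach().clone()
opt_saved = copy.deepcopy(opt.state_dict())

train_step(w, opt, target)
w_after_one = w.detach().clone()  # one step taken from the snapshot

# 1) Roll back the weights only (as in the question), then step again.
with torch.no_grad():
    w.copy_(w_saved)
train_step(w, opt, target)
w_param_only = w.detach().clone()

# 2) Roll back the weights and the optimizer state, then step again.
with torch.no_grad():
    w.copy_(w_saved)
opt.load_state_dict(copy.deepcopy(opt_saved))  # copy again so the snapshot stays reusable
train_step(w, opt, target)
w_full = w.detach().clone()

print(torch.equal(w_param_only, w_after_one))  # False: optimizer state was not rolled back
print(torch.equal(w_full, w_after_one))        # True: same weights and same state
```

If restarting the momentum is acceptable, a lighter alternative would be calling opt.state.clear() right after copying the weights back: the next step() then re-initializes x_prev and y_prev from the restored weights and resets t_prev to 1. With the two optimizers from the question, both optimizer_penalized and optimizer_unpenalized would need the same treatment.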
Any help would be great! It’s been 3 days of no sleep…