Issues with Custom FISTA Optimizer and Model State Rollback in PyTorch

Hello,

I’ve developed a custom FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) optimizer in PyTorch for a project I’m working on. The optimizer works perfectly under normal circumstances. However, I’ve encountered a problem when trying to roll back the model parameters to a previous state during training.

import torch 
import torch.nn as nn

class FISTA(torch.optim.Optimizer):
    def __init__(self, params, lr, lambda_):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if lambda_ < 0.0:
            raise ValueError(f"Invalid lambda: {lambda_} - should be >= 0.0")
        
        defaults = dict(lr=lr, lambda_=lambda_)
        super(FISTA, self).__init__(params, defaults)
    
    def shrinkage_operator(self, u, thresh):
        # Soft-thresholding: the proximal operator of thresh * ||.||_1
        return torch.sign(u) * torch.clamp(torch.abs(u) - thresh, min=0.0)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                lr = group['lr']
                lambda_ = group['lambda_']
                state = self.state[p]
                
                # State initialization
                if len(state) == 0:
                    # clone() so the state does not alias the live parameter tensor
                    state['x_prev'] = p.data.clone()
                    state['y_prev'] = p.data.clone()
                    state['t_prev'] = torch.tensor(1., device=p.device)

                x_prev, y_prev, t_prev = state['x_prev'], state['y_prev'], state['t_prev']

                x_next = self.shrinkage_operator(y_prev - lr * grad, lambda_)
                t_next = (1. + torch.sqrt(1. + 4. * t_prev ** 2)) / 2.
                y_next = x_next + ((t_prev - 1) / t_next) * (x_next - x_prev)

                state['x_prev'], state['y_prev'], state['t_prev'] = x_next, y_next, t_next

                p.data.copy_(x_next)

        return loss
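For reference, step is meant to implement the standard FISTA recursion from Beck & Teboulle (2009), with step size η = lr and the soft-threshold taken directly as lambda_:

$$
\begin{aligned}
x_{k+1} &= T_{\lambda}\big(y_k - \eta\,\nabla f(y_k)\big), \qquad T_{\tau}(u) = \operatorname{sign}(u)\max(|u| - \tau,\ 0),\\
t_{k+1} &= \tfrac{1}{2}\Big(1 + \sqrt{1 + 4t_k^2}\Big),\\
y_{k+1} &= x_{k+1} + \frac{t_k - 1}{t_{k+1}}\,(x_{k+1} - x_k).
\end{aligned}
$$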

This optimizer is then used in my model’s training loop:

optimizer_penalized = FISTA(params=layer1.parameters(), lambda_=lambda_, lr=lr)
optimizer_unpenalized = FISTA(params=layer2.parameters(), lambda_=0.0, lr=lr)

I need two of them because I only want shrinkage applied to the first layer of my model. Now, this optimizer works perfectly fine until I modify layer1 and layer2 from the outside. Say I train my model and, at every epoch that produces a lower cost, I save the layers like so:

layer1_before_dict = (layer1.weight.detach().clone(), layer1.bias.detach().clone())
layer2_before_dict = (layer2.weight.detach().clone(), layer2.bias.detach().clone())

and when an epoch increases the cost, I restore them:

with torch.no_grad():
    layer1.weight.copy_(layer1_before_dict[0])
    layer1.bias.copy_(layer1_before_dict[1])
    layer2.weight.copy_(layer2_before_dict[0])
    layer2.bias.copy_(layer2_before_dict[1])

The issue arises here: after rolling back the layer parameters, the optimizer’s internal state (x_prev, y_prev, t_prev) still reflects the rejected epoch. As a result, training does not actually resume from the previous state as intended.

Any help would be great! It’s been 3 days of no sleep … :slight_smile: :sweat_smile:

That’s weird. Do you see the same issue if you restore the previous weights by loading the state_dict()?
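For example, something along these lines (a rough sketch; it assumes layer1 and layer2 are the modules from your snippet, and layer1_saved / layer2_saved are just placeholder names):

import copy

# on an epoch that improves the cost, snapshot the full module state
layer1_saved = copy.deepcopy(layer1.state_dict())
layer2_saved = copy.deepcopy(layer2.state_dict())

# on a bad epoch, restore it
layer1.load_state_dict(layer1_saved)
layer2.load_state_dict(layer2_saved)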

Yes, I tried that as well. I figured out that the problem is that my x_prev and the other state tensors are not moved back to a previous state when I restore my layers. Maybe I should add a method that resets the optimizer’s internal state?
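Something like this is what I have in mind, on top of the weight rollback above (untested sketch; opt1_saved / opt2_saved are placeholder names). Since FISTA subclasses torch.optim.Optimizer, it inherits state_dict() / load_state_dict(), which pack and restore self.state:

import copy

# whenever the cost improves, snapshot the optimizers' internal state too
opt1_saved = copy.deepcopy(optimizer_penalized.state_dict())
opt2_saved = copy.deepcopy(optimizer_unpenalized.state_dict())

# on a rollback, restore x_prev / y_prev / t_prev along with the weights
optimizer_penalized.load_state_dict(opt1_saved)
optimizer_unpenalized.load_state_dict(opt2_saved)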