Pickling an optimizer with callable attributes


I’m implementing a version of a proximal gradient descent algorithm. For this, I wrote an optimizer which takes a prox callable as an argument (e.g. a method which performs soft thresholding, clipping, or some projection/prox); this prox is saved as an attribute in the defaults dict.

When I try to save the optimizer’s state dict, I get a pickling error: AttributeError: can't pickle local object 'prox'.

Is there any way to get around this?

Code to reproduce this error: Google Colaboratory

Thanks for leaving the link to your notebook

I already ran into this error when defining one of the methods of my PyTorch module as a lambda function, like self.proj = lambda x: x.
When I changed it to self.proj = nn.Identity() the error went away.
So the error likely comes from the fact that prox contains lambda functions.

You can change lambda x, s=None: x, for example, to nn.Identity().

Try to remove the lambda functions and use PyTorch modules instead.

Note: when executing your code I get AttributeError: Can't pickle local object 'PGD.__init__.<locals>.<lambda>'
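A minimal sketch of the difference (names here are illustrative, not from the notebook): pickle serializes functions by reference, i.e. by module plus qualified name, so a lambda with no importable name fails, while an nn.Identity() instance round-trips fine.

```python
import pickle
import torch.nn as nn

# A lambda cannot be pickled: it has no importable qualified name,
# so pickle's lookup-by-reference fails.
prox_lambda = lambda x, s=None: x
try:
    pickle.dumps(prox_lambda)
except Exception as e:
    print(f"lambda failed: {e}")

# An nn.Identity() instance pickles fine, since its class is importable.
restored = pickle.loads(pickle.dumps(nn.Identity()))
print(type(restored).__name__)  # Identity
```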


Ah, thanks @pascal_notsawo – I hadn’t thought of defining the prox operations as nn.Modules. That’s helpful, although it will take some time to fix (I have a bunch of constraints defined in chop/constraints.py at master · openopt/chop · GitHub).

If another way is possible, I’m all ears.

Here is the modification I made, and it works.
Since the prox elements are only used as functions in the end, there is no need to put them in a ModuleList, but you can still do so.

class f(torch.nn.Module):
    def __init__(self, prox_el):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.prox_el = prox_el

    def forward(self, x, s=None):
        return self.prox_el(x.unsqueeze(0)).squeeze()

class PGD(torch.optim.Optimizer):
    """Proximal Gradient Descent
      params: [torch.Parameter]
        List of parameters to optimize over
      prox: [callable or None]
        List of prox operators, one per parameter.
      lr: float
        Learning rate
      momentum: float in [0, 1]
      normalization: str
        Type of gradient normalization to be used.
        Possible values are 'none', 'L2', 'Linf', 'sign'.
    """
    name = 'PGD'
    POSSIBLE_NORMALIZATIONS = {'none', 'L2', 'Linf', 'sign'}

    def __init__(self, params, prox=None, lr=.1, momentum=.9, normalization='none'):
        if prox is None:
            prox = [None] * len(list(params))

        self.prox = []
        for prox_el in prox:
            if prox_el is not None:
        # ...
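To make the wrapper pattern concrete, here is a self-contained sketch; ProxWrapper and soft_threshold are hypothetical names of mine, not part of the thread’s code. The prox is a module-level function, so both the function and the module wrapping it can be pickled by reference.

```python
import pickle
import torch
import torch.nn as nn

# Hypothetical prox: soft thresholding, defined at module level so
# pickle can serialize it by reference (module + qualified name).
def soft_threshold(x, threshold=0.1):
    return torch.sign(x) * torch.clamp(x.abs() - threshold, min=0.0)

# The wrapper pattern from the thread: the callable becomes an
# attribute of an nn.Module subclass, which pickles like any module.
class ProxWrapper(nn.Module):
    def __init__(self, prox_el):
        super().__init__()
        self.prox_el = prox_el

    def forward(self, x, s=None):
        return self.prox_el(x.unsqueeze(0)).squeeze()

wrapped = ProxWrapper(soft_threshold)
x = torch.tensor([0.5, -0.05, 0.2])
print(wrapped(x))
restored = pickle.loads(pickle.dumps(wrapped))
print(torch.equal(restored(x), wrapped(x)))  # True
```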

This is awesome; so it’s just a matter of subclassing nn.Module and putting the content of the lambda function in forward.

Thanks! :partying_face:

Yes, you are welcome

But my solution is only a workaround: when there are no (learnable) parameters, I think it would be simpler (and better) to define these as plain functions. Except that serialization fails in that case if the function is defined in the ordinary way, as a local function.

For example, if I define our function above as below, serialization will fail with an error like: AttributeError: Can't pickle local object 'f.<locals>.g'

def f(prox_el):
    def g(x, s=None):
        return prox_el(x.unsqueeze(0)).squeeze()
    return g

On the other hand, if in my module I define an attribute like self.loss = F.mse_loss, which is just a plain function, serialization succeeds.

So a solution would be to look at how PyTorch defines its own functions.
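A hedged sketch of that idea: PyTorch’s functional ops like F.mse_loss are plain module-level functions, which is exactly why they pickle fine. functools.partial over module-level functions gives a picklable substitute for the closure above; prox_apply and identity_prox are illustrative names of mine.

```python
import pickle
from functools import partial

# Module-level functions pickle by reference, just like F.mse_loss.
def prox_apply(prox_el, x, s=None):
    return prox_el(x)

def identity_prox(x):
    return x

# A closure returned from f() is a local object and cannot be pickled...
def f(prox_el):
    def g(x, s=None):
        return prox_el(x)
    return g

try:
    pickle.dumps(f(identity_prox))
except Exception as e:
    print(e)  # Can't pickle local object 'f.<locals>.g'

# ...but a partial over module-level functions serializes its target
# function and its bound arguments without trouble.
g = partial(prox_apply, identity_prox)
restored = pickle.loads(pickle.dumps(g))
print(restored(3))  # 3
```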
