Hi,
I’m implementing a version of a proximal gradient descent algorithm. For this, I wrote an optimizer that takes a prox callable as an argument (e.g. a method that performs soft thresholding, clipping, or some projection/prox); this prox is saved as an attribute in the defaults dict.
When I try to save the optimizer’s state dict, I get a pickling error: AttributeError: can't pickle local object 'prox'.
Is there any way to get around this?
Code to reproduce this error: Google Colaboratory
Thanks for leaving the link to your notebook.
I already ran into this error when I defined one of the methods of my PyTorch module as a lambda function, like self.proj = lambda x: x. When I changed it to self.proj = nn.Identity(), the error went away.
So the error likely comes from the fact that prox contains lambda functions. You can replace lambda x, s=None: x, for example, with an nn.Identity().
Try removing the lambda functions and using PyTorch modules instead.
Note: when executing your code I get AttributeError: Can't pickle local object 'PGD.__init__.<locals>.<lambda>'
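This failure can be reproduced without torch at all; a minimal sketch (the class and attribute names below are hypothetical stand-ins, not the actual optimizer) of why a lambda stored in a defaults dict breaks pickling:

```python
import pickle

class PGD:
    """Hypothetical stand-in for an optimizer whose defaults hold a lambda."""
    def __init__(self):
        # The lambda's qualified name is 'PGD.__init__.<locals>.<lambda>'.
        # pickle serializes functions by reference to their qualified name,
        # and cannot resolve names containing '<locals>', so dumps() fails.
        self.defaults = {'prox': lambda x, s=None: x}

opt = PGD()
try:
    pickle.dumps(opt.defaults)
except (AttributeError, pickle.PicklingError) as e:
    print(e)  # e.g. Can't pickle local object 'PGD.__init__.<locals>.<lambda>'
```

Anything reachable from the optimizer's state dict must therefore be picklable by name: module-level functions and classes work, lambdas and closures do not.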
Ah, thanks @pascal_notsawo – I hadn’t thought of defining the prox operations as nn.Modules. That’s helpful, although it will take some time to fix (I have a bunch of constraints defined in chop/constraints.py at master · openopt/chop · GitHub).
If another way is possible, I’m all ears.
Here is the modification I made, and it works.
Since the prox elements are only ever called as functions in the end, there is no need to put them in a ModuleList, but you can still do so.
class f(torch.nn.Module):
    def __init__(self, prox_el):
        super().__init__()
        self.prox_el = prox_el

    def forward(self, x, s=None):
        return self.prox_el(x.unsqueeze(0)).squeeze()


class PGD(torch.optim.Optimizer):
    """Proximal Gradient Descent

    Args:
        params: [torch.Parameter]
            List of parameters to optimize over
        prox: [callable or None]
            List of prox operators, one per parameter.
        lr: float
            Learning rate
        momentum: float in [0, 1]
        normalization: str
            Type of gradient normalization to be used.
            Possible values are 'none', 'L2', 'Linf', 'sign'.
    """
    name = 'PGD'
    POSSIBLE_NORMALIZATIONS = {'none', 'L2', 'Linf', 'sign'}

    def __init__(self, params, prox=None, lr=.1, momentum=.9, normalization='none'):
        params = list(params)  # materialize once so an iterator isn't consumed by len()
        if prox is None:
            prox = [None] * len(params)
        self.prox = []
        for prox_el in prox:
            if prox_el is not None:
                self.prox.append(f(prox_el))
            else:
                self.prox.append(torch.nn.Identity())
        # ...
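Why wrapping in a module-level class fixes pickling can be seen with a torch-free analogue (the Identity class below is a hypothetical stand-in for nn.Identity, not PyTorch code): pickle stores a reference to the class, which is importable by name, plus the instance attributes, so no function code ever needs to be serialized.

```python
import pickle

class Identity:
    """Module-level callable, analogous to nn.Identity: pickle records a
    reference to this class plus the (empty) instance dict."""
    def __call__(self, x, s=None):
        return x

prox = Identity()
restored = pickle.loads(pickle.dumps(prox))  # round-trips without error
print(restored(3))  # -> 3
```

Note that the same reasoning means the wrapped prox_el itself must still be picklable (e.g. a bound method of a picklable object); wrapping a lambda inside a module does not make the lambda picklable.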
This is awesome; so just subclass nn.Module and put the body of the lambda function in forward.
Thanks!
Yes, you’re welcome.
But my solution is only a workaround: I think it’s simpler (and better) to define these as plain functions when there are no learnable parameters. Except that serialization fails in that case if we define the function in the ordinary nested way.
For example, if I define our function above as below, serialization fails with an error like: AttributeError: Can't pickle local object 'f.<locals>.g'
def f(prox_el):
    def g(x, s=None):
        return prox_el(x.unsqueeze(0)).squeeze()
    return g
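A minimal sketch of that failure, with the tensor ops replaced by a plain call so it runs without torch:

```python
import pickle

def f(prox_el):
    # g is a closure: its qualified name is 'f.<locals>.g'. pickle
    # serializes functions by reference to such names and cannot
    # resolve local ones, so dumps() fails.
    def g(x, s=None):
        return prox_el(x)
    return g

try:
    pickle.dumps(f(abs))
except (AttributeError, pickle.PicklingError) as e:
    print(e)  # e.g. Can't pickle local object 'f.<locals>.g'
```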
On the other hand, if in my module I define an attribute like self.loss = F.mse_loss, which is also just a function, serialization succeeds. So a solution would be to look at how PyTorch defines its own functions: F.mse_loss is a module-level function, which pickle can locate by its qualified name.
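Indeed, a module-level function is pickled by reference: pickle records only the module path and qualified name, then re-imports the function on load. A torch-free sketch using math.sqrt in place of F.mse_loss:

```python
import math
import pickle

# Module-level functions pickle by reference: only the qualified name
# 'math.sqrt' is stored, not any code object.
payload = pickle.dumps(math.sqrt)
fn = pickle.loads(payload)
print(fn(9.0))  # -> 3.0 (and fn is the very same math.sqrt object)
```

So a prox defined as a top-level function in its own module would serialize fine; only lambdas and closures break, because their qualified names contain '&lt;locals&gt;'.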