How does one pickle arbitrary PyTorch models that use lambda functions?

I currently have a neural network module:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self, args, lambda_f, nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Parameter stuff etc...

    def forward(self, x):
        # some code using the fields; the line below is only a placeholder
        out = self.lambda_f(self.nn1(x))
        return out

I am trying to checkpoint it, but because PyTorch saves via state_dicts, I can't save the lambda functions I was actually using if I checkpoint with torch.save. I want to save everything and reload it later to train on GPUs. I am currently using this:

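from pathlib import Path

import dill as pickle


def save_ckpt(path_to_ckpt, crazy_mdl):
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt = Path(path_to_ckpt)
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / 'db'

    ## Pickle the whole model object, lambdas included
    db = {'crazy_mdl': crazy_mdl}
    ## 'wb' rather than 'ab': appending stacks a new pickle after the old ones,
    ## and a single load() would then return only the first (stale) one
    with open(ckpt_path_plus_path, 'wb') as db_file:
        pickle.dump(db, db_file)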

Currently it throws no errors when I checkpoint it, and the file is saved.
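For context on why dill is involved at all: torch.save uses the standard pickle module by default, and standard pickle serializes functions by name, so an anonymous lambda cannot be looked up again. The sketch below uses only the standard library to show that failure, plus one possible workaround (a functools.partial over a named function). The names scale and scale_partial are made up for illustration and are not from the model above.

```python
import functools
import operator
import pickle

# Same kind of callable as lambda_f in the question (illustrative only).
scale = lambda x: 2 * x

# Standard pickle stores functions by reference to their name, and a lambda
# has no importable name, so this is expected to fail.
try:
    pickle.dumps(scale)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError):
    lambda_pickled = False

# One possible replacement: a partial over a named, importable function.
# operator.mul(2, x) computes 2 * x, and partial objects are picklable.
scale_partial = functools.partial(operator.mul, 2)
restored = pickle.loads(pickle.dumps(scale_partial))

print(lambda_pickled)   # whether plain pickle accepted the lambda
print(restored(21))     # the restored callable still doubles its input
```

dill works around this by serializing the function's code itself, which is why the dill-based checkpoint raises no error. The trade-off is that the pickled file is then tied to the Python and library versions that wrote it.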

I am worried that there might be a subtle bug when I train with it, even if no exceptions/errors are raised, or that something unexpected might happen (e.g. weird saving to disk on the clusters, who knows).

Is this safe to do with PyTorch classes/nn models? Especially if we want to resume training on GPUs?
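For comparison, here is a minimal sketch of the state_dict route, assuming the lambda can be recreated from code when the model is rebuilt. TinyNet, build_model, save_state and load_state are hypothetical names, and the SGD optimizer is only a stand-in for whatever opt is in the class above. The point is that the lambda never goes into the checkpoint, and map_location decides which device the tensors land on when resuming.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the NN class in the question."""

    def __init__(self, lambda_f):
        super().__init__()
        self.lambda_f = lambda_f      # plain attribute, not part of state_dict
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.lambda_f(self.fc(x))


def build_model():
    # The lambda lives in code, so it is recreated on every build
    # instead of being stored in the checkpoint.
    return TinyNet(lambda_f=lambda t: torch.relu(t))


def save_state(path, model, opt, epoch):
    # Only tensors and plain Python values are written to disk.
    torch.save(
        {"model": model.state_dict(), "opt": opt.state_dict(), "epoch": epoch},
        path,
    )


def load_state(path, device):
    # map_location remaps saved tensors onto the requested device.
    ckpt = torch.load(path, map_location=device)
    model = build_model().to(device)
    model.load_state_dict(ckpt["model"])
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    opt.load_state_dict(ckpt["opt"])
    return model, opt, ckpt["epoch"]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
save_state(ckpt_path, model, opt, epoch=3)
restored, restored_opt, epoch = load_state(ckpt_path, device)
```

This does not cover the case where the lambda itself is arbitrary and unknown at load time. There, something like the dill approach (or replacing the lambda with a named function or small module) is still needed.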


Cross posted:


related: Advantages & Disadvantages of using pickle module to save models vs torch.save

In case this is still an issue, we have a (partly satisfactory) solution here.