I currently have a neural network module:
    import torch.nn as nn

    class NN(nn.Module):
        def __init__(self, args, lambda_f, nn1, loss, opt):
            super().__init__()
            self.args = args
            self.lambda_f = lambda_f
            self.nn1 = nn1
            self.loss = loss
            self.opt = opt
            # more nn.Params stuff etc...

        def forward(self, x):
            # some code using fields
            return out
I am trying to checkpoint it, but because PyTorch's recommended saving goes through state_dicts, I can't save the lambda functions I am actually using if I checkpoint with torch.save etc. I literally want to save everything without issue and reload it later to resume training on GPUs. I am currently using this:
    def save_ckpt(path_to_ckpt, crazy_mdl):
        from pathlib import Path
        import dill as pickle
        ## Make dir. Throw no exceptions if it already exists
        path_to_ckpt.mkdir(parents=True, exist_ok=True)
        ckpt_path_plus_path = path_to_ckpt / Path('db')
        ## Pickle the model (and whatever else) into a single dict
        db = {'crazy_mdl': crazy_mdl}
        ## 'wb' rather than 'ab': appending stacks pickles, so a reload would return the oldest one
        with open(ckpt_path_plus_path, 'wb') as db_file:
            pickle.dump(db, db_file)
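For the resume side, this is the loading counterpart I would pair with it; a minimal sketch, assuming the checkpoint was written by save_ckpt above (the load_ckpt name and the device argument are mine, not part of my current code):

    def load_ckpt(path_to_ckpt, device='cuda'):
        from pathlib import Path
        import dill as pickle
        ckpt_path_plus_path = path_to_ckpt / Path('db')
        ## Unpickle the dict that save_ckpt wrote
        with open(ckpt_path_plus_path, 'rb') as db_file:
            db = pickle.load(db_file)
        crazy_mdl = db['crazy_mdl']
        ## Move parameters/buffers onto the GPU before resuming training
        crazy_mdl.to(device)
        return crazy_mdl

One caveat: raw dill.load has no map_location, so if the pickle contains CUDA tensors it can only be unpickled on a machine where CUDA is available; saving the model from CPU first (or loading via torch.load with pickle_module=dill and map_location) avoids that.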
Currently it throws no errors when I checkpoint, and the file does get saved.
I am worried that when I resume training there might be a subtle bug even if no exceptions/errors are raised, or that something unexpected might happen (e.g. weird saving behavior on the cluster's disks, etc., who knows).
Is this safe to do with PyTorch classes/nn.Module models, especially if we want to resume training on GPUs?
- How does one pickle arbitrary PyTorch models that use lambda functions?
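On that last question: the blocker is Python's built-in pickle (which torch.save uses by default), not PyTorch itself, and torch.save/torch.load accept a pickle_module argument, so dill can be plugged into PyTorch's own serialization. A minimal sketch of the round trip, assuming the NN class from the snippet above is in scope; the toy constructor arguments and the file name crazy_mdl.pt are made up for illustration:

    import torch
    import torch.nn as nn
    import dill

    # Toy instance of the NN module above, holding a lambda
    model = NN(args=None,
               lambda_f=lambda x: x ** 2,
               nn1=nn.Linear(3, 3),
               loss=nn.MSELoss(),
               opt=None)

    # Hand dill to torch.save/torch.load so the lambda can be serialized
    torch.save(model, 'crazy_mdl.pt', pickle_module=dill)
    # On newer PyTorch versions you may also need weights_only=False here
    restored = torch.load('crazy_mdl.pt', pickle_module=dill)

    # Sanity check: the lambda survived the round trip
    x = torch.randn(2, 3)
    assert torch.equal(restored.lambda_f(x), x ** 2)

The usual caveat with pickling whole models still applies either way: the saved file is generally tied to the class definition and library versions used to write it.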