Hello,
I have a network architecture like below that chooses different options based an input argument.
When I try to save the model at a certain epoch during training using torch.save() I get:
AttributeError: Can't pickle local object 'main.<locals>.Net1'
Here is the part that I design the network in the main():
IE_dim = X_tr.shape[1]
if args.hd == 1:
class Net1(nn.Module):
def __init__(self, args):
super(Net1, self).__init__()
self.features = torch.nn.Sequential(
nn.Dropout(args.idr),
nn.Linear(IE_dim, 512),
nn.Tanh(),
nn.Dropout(args.ldr),
nn.Linear(512, 1))
def forward(self, x):
out = self.features(x)
return out
Model = Net1(args)
elif args.hd == 2:
class Net2(nn.Module):
def __init__(self, args):
super(Net2, self).__init__()
self.features = torch.nn.Sequential(
nn.Dropout(args.idr),
nn.Linear(IE_dim, 256),
nn.Tanh(),
nn.Dropout(args.ldr),
nn.Linear(256, 256),
nn.Tanh(),
nn.Dropout(args.ldr),
nn.Linear(256, 1))
def forward(self, x):
out = self.features(x)
return out
Model = Net2(args)
elif args.hd == 3:
class Net3(nn.Module):
def __init__(self, args):
super(Net3, self).__init__()
self.features = torch.nn.Sequential(
nn.Dropout(args.idr),
nn.Linear(IE_dim, 128),
nn.Tanh(),
nn.Dropout(args.ldr),
nn.Linear(128, 128),
nn.Tanh(),
nn.Dropout(args.ldr),
nn.Linear(128, 128),
nn.Tanh(),
nn.Dropout(args.ldr),
nn.Linear(128, 1))
def forward(self, x):
out = self.features(x)
return out
Model = Net3(args)
and here is how I use torch.save():
torch.save(Model, os.path.join(SOME PATH, 'Best_Model.pt'))
I had that piece of code first in a separate .py file and used import * in the main() but got the above error and then, I moved the code to the main() function but got the same error.
I appreciate any help!