torch.save and AttributeError: Can't pickle local object

Hello,
I have a network architecture, shown below, that selects one of several options based on an input argument.
When I try to save the model at a certain epoch during training using torch.save(), I get:

AttributeError: Can't pickle local object 'main.<locals>.Net1'

Here is the part of main() where I define the network:

    IE_dim = X_tr.shape[1]
    if args.hd == 1:
        class Net1(nn.Module):
            def __init__(self, args):
                super(Net1, self).__init__()

                self.features = torch.nn.Sequential(
                    nn.Dropout(args.idr),
                    nn.Linear(IE_dim, 512),
                    nn.Tanh(),
                    nn.Dropout(args.ldr),
                    nn.Linear(512, 1))

            def forward(self, x):
                out = self.features(x)
                return out  
        Model = Net1(args)

    elif args.hd == 2: 
        class Net2(nn.Module):
            def __init__(self, args):
                super(Net2, self).__init__()

                self.features = torch.nn.Sequential(
                    nn.Dropout(args.idr),
                    nn.Linear(IE_dim, 256),
                    nn.Tanh(),
                    nn.Dropout(args.ldr),
                    nn.Linear(256, 256),
                    nn.Tanh(),
                    nn.Dropout(args.ldr),
                    nn.Linear(256, 1)) 

            def forward(self, x):
                out = self.features(x)
                return out            
        Model = Net2(args)

    elif args.hd == 3:
        class Net3(nn.Module):
            def __init__(self, args):
                super(Net3, self).__init__()    

                self.features = torch.nn.Sequential(
                    nn.Dropout(args.idr),
                    nn.Linear(IE_dim, 128),
                    nn.Tanh(),
                    nn.Dropout(args.ldr),
                    nn.Linear(128, 128),
                    nn.Tanh(),
                    nn.Dropout(args.ldr),
                    nn.Linear(128, 128),
                    nn.Tanh(),
                    nn.Dropout(args.ldr),
                    nn.Linear(128, 1))                        

            def forward(self, x):
                out = self.features(x)
                return out
        Model = Net3(args)        

And here is how I call torch.save():

torch.save(Model, os.path.join(SOME PATH, 'Best_Model.pt'))
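To help narrow this down, here is a minimal sketch that does not need PyTorch or any data. It assumes the failure comes from Python's pickle module, which torch.save() uses by default when it is handed a whole model object. The names TopLevelNet, make_local_net and can_pickle are hypothetical and exist only for this illustration; they are not part of the script above.

```python
import pickle


class TopLevelNet:
    """Stand-in for a model class defined at module level (hypothetical)."""

    def __init__(self, hidden):
        self.hidden = hidden


def make_local_net(hidden):
    # The class is created inside the function, so its qualified name is
    # 'make_local_net.<locals>.LocalNet'. pickle stores classes by name and
    # cannot look up a name that only exists inside a function call.
    class LocalNet:
        def __init__(self, hidden):
            self.hidden = hidden

    return LocalNet(hidden)


def can_pickle(obj):
    """Return (True, None) if obj pickles, else (False, error message)."""
    try:
        pickle.dumps(obj)
        return True, None
    except (AttributeError, pickle.PicklingError) as exc:
        # Older Python versions raise AttributeError here; catching
        # PicklingError as well keeps the sketch version-tolerant.
        return False, str(exc)


local_ok, local_err = can_pickle(make_local_net(512))
top_ok, top_err = can_pickle(TopLevelNet(512))

print("class defined inside a function:", local_ok, local_err)
print("class defined at module level:  ", top_ok, top_err)
```

When this is run as a script, the locally defined class should fail in the same way as Net1 above, while the module-level class is expected to pickle. That would suggest the location of the class definition matters rather than the layers inside it.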

I originally had this code in a separate .py file and used import * in main(), which gave the above error. I then moved the code into the main() function but got the same error.
I appreciate any help!