I have a defined a method that builds a model, depending on a certain config (global variable):
def build_model():
if config.model_type == 'MLPa':
model = MLPa()
elif config.model_type == 'MLPb':
model = MLPb()
elif config.model_type == 'CNN':
model = CNN()
if torch.cuda.is_available():
model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config.lr0, weight_decay = config.weight_decay/config.batch_size)
return model, criterion, optimizer
MLPa is defined like so:
class MLPa(nn.Module):
def __init__(self):
super(MLPa, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 600),
nn.ReLU(),
nn.Linear(600, 200),
nn.ReLU(),
nn.Linear(200, 10)),
def forward(self, x):
output = self.model(x)
return output
The other model classes are similarly defined.
However, on running:
model, criterion, optimizer = build_model()
I recieve:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-12-b7697f348cad> in <module>()
----> 1 model, criterion, optimizer = build_model()
2 train_model(model, criterion, optimizer)
<ipython-input-8-077c576fd794> in build_model()
11
12 criterion = nn.CrossEntropyLoss()
---> 13 optimizer = optim.SGD(model.parameters(), lr=config.lr0, weight_decay = config.weight_decay/config.batch_size)
14 return model, criterion, optimizer
/anaconda/envs/py36/lib/python3.6/site-packages/torch/optim/sgd.py in __init__(self, params, lr, momentum, dampening, weight_decay, nesterov)
55 if nesterov and (momentum <= 0 or dampening != 0):
56 raise ValueError("Nesterov momentum requires a momentum and zero dampening")
---> 57 super(SGD, self).__init__(params, defaults)
58
59 def __setstate__(self, state):
/anaconda/envs/py36/lib/python3.6/site-packages/torch/optim/optimizer.py in __init__(self, params, defaults)
32 param_groups = list(params)
33 if len(param_groups) == 0:
---> 34 raise ValueError("optimizer got an empty parameter list")
35 if not isinstance(param_groups[0], dict):
36 param_groups = [{'params': param_groups}]
ValueError: optimizer got an empty parameter list
Can anybody suggest a solution please?
Also if there is a better way to build different models, depending on a config file, I would love to hear it.