Optimizer on multiple neural networks

Hello,
I have created a neural network like this: F1 = feedforward network --> RNN --> F2 = feedforward network, and I want to train them, but I’m having issues with my optimizer:

# MLP_in ----------------------------------------------------------------------

model_in = MLPModel(input_dim, args.hidden_dim_1_in, args.hidden_dim_2_in, args.hidden_dim_3_in, args.output_dim_in)

if torch.cuda.is_available():
    model_in.cuda()

# RNN -------------------------------------------------------------------------

model_RNN = LSTMModel(input_dim, args.hidden_dim, args.layer_dim, args.output_dim_rnn)

if torch.cuda.is_available():
    model_RNN.cuda()

# MLP_out ---------------------------------------------------------------------

model_out = MLPModel(args.output_dim_rnn, args.hidden_dim_1_out, args.hidden_dim_2_out, args.hidden_dim_3_out, args.output_dim)

if torch.cuda.is_available():
    model_out.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

Should I replace model.parameters() with the combined parameters of the three models, or is there another method to do this?
Thank you.

You could pass the parameters together with:

import torch.optim as optim

params = list(model_in.parameters()) + list(model_RNN.parameters()) + list(model_out.parameters())
optimizer = optim.SGD(params, lr=learning_rate)
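
Alternatively, you could register the three sub-networks as submodules of one container module, so a single .parameters() call (and a single .cuda() call) covers all of them. A minimal sketch, assuming each sub-network's output shape matches the next one's expected input (FullModel is just an illustrative name):

import torch.nn as nn
import torch.optim as optim

class FullModel(nn.Module):
    def __init__(self, model_in, model_RNN, model_out):
        super().__init__()
        # Assigning the sub-networks as attributes registers them as
        # submodules, so self.parameters() yields all their parameters.
        self.model_in = model_in
        self.model_RNN = model_RNN
        self.model_out = model_out

    def forward(self, x):
        # F1 --> RNN --> F2, as in the architecture above.
        return self.model_out(self.model_RNN(self.model_in(x)))

model = FullModel(model_in, model_RNN, model_out)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)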

I constructed an optimizer like that, but it raises an error when I apply torch.nn.utils.clip_grad_norm_ to the optimizer. I have tried two ways:
params = list(model1.parameters()) + list(model2.parameters())
optimizer = optim.SGD(params, lr=learning_rate)

1. torch.nn.utils.clip_grad_norm_(optimizer, clip_value)
   The error message is:
   TypeError: 'SGD' object is not iterable
2. torch.nn.utils.clip_grad_norm_(optimizer.parameters(), clip_value)
   The error message is:
   AttributeError: 'SGD' object has no attribute 'parameters'

Is there any way to do this?

clip_grad_norm_ expects parameters, not an optimizer.
Try passing params or model.parameters() to it instead.
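
For example, a minimal sketch of one training step with clipping, reusing the params list built above (loss_fn, inputs, and targets are illustrative placeholders):

import torch

optimizer.zero_grad()
output = model2(model1(inputs))  # chain the two sub-networks
loss = loss_fn(output, targets)
loss.backward()
# Clip the gradients of the same parameter list the optimizer updates,
# after backward() and before step().
torch.nn.utils.clip_grad_norm_(params, clip_value)
optimizer.step()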

You are right. Thanks!