Using nn.Module vs Sequential - Optimizer does not do the same thing

Hello everyone,

I am new to PyTorch, so please excuse me if my problem is trivial :). At the moment I am trying to rewrite some code that implements the reinforcement learning algorithm Soft Actor-Critic (SAC) with an LSTM layer shared by all the neural networks involved. My version of the code is identical to the original, except that I created an nn.Module subclass for every type of network (e.g. SoftQNetwork, ValueNetwork, PolicyNetwork), whereas the original code builds them with nn.Sequential.

As in the original, I use a single Adam optimizer for the parameters of all networks. The problem is that the original code trains well (I am currently training on Pendulum), whereas my version does not seem to work.

Does anyone have an idea what I might be doing wrong?

Thank you in advance!

The original code:

class SACRNN(nn.Module):
    def __init__(self,
                 input_size,
                 action_size,
                 gamma=0.99,
                 reward_scale=1,
                 alpha=1.0,
                 shared_layers=256,
                 output_layers=256,
                 lr=3e-4):
        super(SACRNN, self).__init__()

        ....
        self.rnn = nn.LSTM(input_size=self.input_size, hidden_size=self.shared_layers,
                           num_layers=1, batch_first=False, bias=True)

        self.out_s2mua = nn.Sequential(nn.Linear(self.shared_layers, self.output_layers),
                                       nn.ReLU(),
                                       nn.Linear(self.output_layers, self.output_layers),
                                       nn.ReLU(),
                                       nn.Linear(self.output_layers, self.action_size))

        self.out_s2log_siga = nn.Sequential(nn.Linear(self.shared_layers, self.output_layers),
                                        nn.ReLU(),
                                        nn.Linear(self.output_layers, self.output_layers),
                                        nn.ReLU(),
                                        nn.Linear(self.output_layers, self.action_size))

        self.out_s2v = nn.Sequential(nn.Linear(self.shared_layers, self.output_layers),
                                     nn.ReLU(),
                                     nn.Linear(self.output_layers, self.output_layers),
                                     nn.ReLU(),
                                     nn.Linear(self.output_layers, 1))

        self.out_sa2q1 = nn.Sequential(nn.Linear(self.shared_layers + self.action_size, self.output_layers),
                                       nn.ReLU(),
                                       nn.Linear(self.output_layers, self.output_layers),
                                       nn.ReLU(),
                                       nn.Linear(self.output_layers, 1))

        self.out_sa2q2 = nn.Sequential(nn.Linear(self.shared_layers + self.action_size, self.output_layers),
                                       nn.ReLU(),
                                       nn.Linear(self.output_layers, self.output_layers),
                                       nn.ReLU(),
                                       nn.Linear(self.output_layers, 1))

        self.out_s2v_tar = nn.Sequential(nn.Linear(self.shared_layers, self.output_layers),
                                         nn.ReLU(),
                                         nn.Linear(self.output_layers, self.output_layers),
                                         nn.ReLU(),
                                         nn.Linear(self.output_layers, 1))
        # synchronizing target V network and V network

        state_dict_tar = self.out_s2v_tar.state_dict()
        state_dict = self.out_s2v.state_dict()
        for key in list(self.out_s2v.state_dict().keys()):
            state_dict_tar[key] = state_dict[key]
        self.out_s2v_tar.load_state_dict(state_dict_tar)

        self.optimizer_v = torch.optim.Adam(self.parameters(), lr=lr)

......

        self.optimizer_v.zero_grad()
        total_loss.backward()
        self.optimizer_v.step()
 

My code:

class LSTMNetwork(nn.Module):
    def __init__(self,
                 input_size=3,
                 shared_layers=256,
                 ):
        
        super().__init__()
        self.layers = nn.LSTM(input_size=input_size, hidden_size=shared_layers,
                           num_layers=1, batch_first=False, bias=True)

    def forward(self, state, hidden=None):
        if hidden is not None:
            lstm_state, lstm_hidden = self.layers(state, hidden)
        else:
            lstm_state, lstm_hidden = self.layers(state)

        return lstm_state, lstm_hidden

class ValueNetwork(nn.Module):
    def __init__(self,
                 shared_layers=256,
                 output_layers=256,
                 ):
        
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(shared_layers, output_layers),
                                    nn.ReLU(),
                                    nn.Linear(output_layers, output_layers),
                                    nn.ReLU(),
                                    nn.Linear(output_layers, 1))

    def forward(self, lstm_state):
        x = self.layers(lstm_state)
        return x

class SoftQNetwork(nn.Module):
    def __init__(self,
                action_size,
                shared_layers=256,
                output_layers=256
                ):

        super().__init__()
        self.layers = nn.Sequential(nn.Linear(shared_layers + action_size, output_layers),
                                    nn.ReLU(),
                                    nn.Linear(output_layers, output_layers),
                                    nn.ReLU(),
                                    nn.Linear(output_layers, 1))

    def forward(self, lstm_state, action):
        x = self.layers(torch.cat((lstm_state, action), dim=-1))
        return x

class PolicyNetwork(nn.Module):
    def __init__(self,
                action_size,
                shared_layers=256,
                output_layers=256,
                ):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(shared_layers, output_layers),
                                    nn.ReLU(),
                                    nn.Linear(output_layers, output_layers),
                                    nn.ReLU(),
                                    nn.Linear(output_layers, action_size))
    
    def forward(self, lstm_state):
        x = self.layers(lstm_state)
        return x

class SACRNN():
    def __init__(self,
                 input_size,
                 action_size,
                 gamma=0.99,
                 reward_scale=1,
                 alpha=1.0,
                 shared_layers=256,
                 output_layers=256,
                 tau=0.005,
                 lr=3e-4):

.....
        self.rnn = LSTMNetwork(self.input_size, self.shared_layers) 

        self.out_s2mua = PolicyNetwork(self.action_size, self.shared_layers, self.output_layers)
        self.out_s2log_siga = PolicyNetwork(self.action_size, self.shared_layers, self.output_layers)

        self.out_s2v = ValueNetwork(self.shared_layers, self.output_layers)
        self.out_s2v_tar = ValueNetwork(self.shared_layers, self.output_layers)

        self.out_sa2q1 = SoftQNetwork(self.action_size, self.shared_layers, self.output_layers)
        self.out_sa2q2 = SoftQNetwork(self.action_size, self.shared_layers, self.output_layers)

        # synchronizing target V network and V network
        state_dict_tar = self.out_s2v_tar.state_dict()
        state_dict = self.out_s2v.state_dict()
        for key in list(self.out_s2v.state_dict().keys()):
            state_dict_tar[key] = state_dict[key]
        self.out_s2v_tar.load_state_dict(state_dict_tar)

        all_params = [
            *self.rnn.layers.parameters(),
            *self.out_s2mua.layers.parameters(),
            *self.out_s2log_siga.layers.parameters(),          
            *self.out_s2v.layers.parameters(),
            *self.out_sa2q1.layers.parameters(),
            *self.out_sa2q2.layers.parameters(),
            *self.out_s2v_tar.layers.parameters(),           
            ]
        
        self.optimizer_v = torch.optim.Adam(all_params, lr=lr)

....
        self.optimizer_v.zero_grad()
        total_loss.backward()
        self.optimizer_v.step()

As a first debugging step, you could make sure that the same parameters are properly registered and returned by model.parameters() in both approaches. Once this is done, make sure you are not freezing some parameters differently (if parameter freezing is used at all). If that is also verified, I'm unsure why the optimizer would not perform the same updates, so you would need to explain this concern in a bit more detail.
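
For example, a minimal sketch of such a check could compare what each Adam optimizer actually receives. Here original and mine are placeholder names for already-constructed instances of the two SACRNN variants, and describe_params is just an illustrative helper:

# Hypothetical sanity check: 'original' and 'mine' are placeholder names for
# already-constructed instances of the Sequential-based and nn.Module-based agents.

def describe_params(params):
    # Summarize a parameter collection: number of tensors, total number of
    # elements, and how many tensors are frozen (requires_grad == False).
    params = list(params)
    n_tensors = len(params)
    n_elements = sum(p.numel() for p in params)
    n_frozen = sum(1 for p in params if not p.requires_grad)
    return n_tensors, n_elements, n_frozen

# Both versions store their Adam optimizer in optimizer_v, so compare the
# parameter lists the two optimizers were actually given.
print(describe_params(original.optimizer_v.param_groups[0]["params"]))
print(describe_params(mine.optimizer_v.param_groups[0]["params"]))

If the tensor counts, element counts, or frozen counts differ, a module was either never registered or accidentally left out of the explicit parameter list.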

Thank you very much for your response! I tried the things you suggested and you were right: the problem wasn't the optimizer but a typo I hadn't noticed.