Layer-wise learning rate does not function the same way as a single learning rate


(Sharafath Mohammed) #1

I am working on a reinforcement learning agent that can adapt to dynamic targets, with a neural network as its function approximator. I want the learning rate of certain layers to be zero after the initial training is complete. The first training happens with the optimizer like so,

self.optimizer = torch.optim.Adam(self.parameters(),
                                  lr=self.learning_rate)
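
As far as I understand, passing self.parameters() like this puts every layer into a single parameter group, so the one lr applies to the body and both heads together. A quick sanity check (just a sketch, run inside the network class after the optimizer is built) would be:

for group in self.optimizer.param_groups:
    # with self.parameters() there is exactly one group holding all the weights
    print(len(group["params"]), group["lr"])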

That optimizer is created during initialisation. But after the initial training is done, when I try to set the layer-wise learning rates with the following,

a2c.network1.optimizer = Adam([
        {"params": a2c.network1.body.parameters(), "lr": SLR[j]},
        {"params": a2c.network1.policy.parameters(), "lr": SLR[j]},
        {"params": a2c.network1.value.parameters(), "lr": SLR[j]},
    ], lr=SLR[j])

the results are very bad. To see whether the problem was with my code, I also ran the entire program with the optimizer above set during initialisation, and the agents performed very badly even with every group given the same learning rate as the single universal one. I think there could be some issue with my implementation. Can anyone help me see what I am doing wrong?
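
For reference, the effect I am after is essentially to freeze the shared body after the first round of training while the policy and value heads keep learning. A minimal sketch of the two ways I understand this could be expressed (either a zero learning rate for the body group, or switching off its gradients entirely) is:

# Sketch only, not the code I am currently running.
# Option A: per-group learning rates, with the body group set to zero
optimizer = Adam([
        {"params": a2c.network1.body.parameters(), "lr": 0},
        {"params": a2c.network1.policy.parameters(), "lr": SLR[j]},
        {"params": a2c.network1.value.parameters(), "lr": SLR[j]},
    ])

# Option B: freeze the body so it receives no gradients at all
for p in a2c.network1.body.parameters():
    p.requires_grad = False
optimizer = Adam((p for p in a2c.network1.parameters() if p.requires_grad),
                 lr=SLR[j])

As far as I can tell, Option A keeps the body parameters inside the optimizer (their steps are just scaled to zero), whereas Option B removes them from the optimizer entirely.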

The full code of the network, the environment and the main body is given below:

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import matplotlib.pyplot as plt
from matplotlib import gridspec


class actorCriticNet(nn.Module):
    def __init__(self, n_hidden_layers, n_hidden_nodes,
                 learning_rate, bias=False, device='cpu'):
        super(actorCriticNet, self).__init__()
        
        self.device = device
        self.n_inputs = 30
        self.n_outputs = 10
        self.n_hidden_nodes = n_hidden_nodes
        self.n_hidden_layers = n_hidden_layers
        self.learning_rate = learning_rate
        self.bias = bias
        self.action_space = np.arange(self.n_outputs)
        
        # Generate network according to hidden layer and node settings
        self.layers = OrderedDict()
        self.n_layers = 2 * self.n_hidden_layers
        for i in range(self.n_layers + 1):
            # Define single linear layer
            if self.n_hidden_layers == 0:
                self.layers[str(i)] = nn.Linear(
                    self.n_inputs,
                    self.n_outputs,
                    bias=self.bias)
            # Define input layer for multi-layer network
            elif i % 2 == 0 and i == 0 and self.n_hidden_layers != 0:
                self.layers[str(i)] = nn.Linear( 
                    self.n_inputs, 
                    self.n_hidden_nodes,
                    bias=self.bias)
            # Define intermediate hidden layers
            elif i % 2 == 0 and i != 0:
                self.layers[str(i)] = nn.Linear(
                    self.n_hidden_nodes,
                    self.n_hidden_nodes,
                    bias=self.bias)
            else:
                self.layers[str(i)] = nn.ReLU()
                
        self.body = nn.Sequential(self.layers)
            
        # Define policy head
        self.policy = nn.Sequential(
            nn.Linear(self.n_hidden_nodes,
                      self.n_hidden_nodes,
                      bias=self.bias),
            nn.ReLU(),
            nn.Linear(self.n_hidden_nodes,
                      self.n_outputs,
                      bias=self.bias))
        # Define value head
        self.value = nn.Sequential(
            nn.Linear(self.n_hidden_nodes,
                      self.n_hidden_nodes,
                      bias=self.bias),
            nn.ReLU(),
            nn.Linear(self.n_hidden_nodes,
                      1, 
                      bias=self.bias))

        if self.device == 'cuda':
            self.cuda()

        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=self.learning_rate)
            
    def predict(self, state):
        body_output = self.get_body_output(state)
        probs = F.softmax(self.policy(body_output), dim=-1)
        return probs, self.value(body_output)

    def get_body_output(self, state):
        state_t = torch.FloatTensor(state).to(device=self.device)
        return self.body(state_t)
    
    def get_action(self, state):
        probs = self.predict(state)[0][0][0].detach().numpy()
        action = np.random.choice(self.action_space, p=probs)
        return action
    
    def get_log_probs(self, state):
        body_output = self.get_body_output(state)
        logprobs = F.log_softmax(self.policy(body_output), dim=-1)
        return logprobs
class A2C():
    global input1, input2, output1, output2, tgt, tgt_list
            
    
    def __init__(self, network1, network2):
        
        self.network1 = network1
        self.network2 = network2
        self.action_space = np.arange(100)
        self.target1 = 20
        self.target2 = 20
        
    def generate_episode(self):
        states, actions1, actions2, rewards1, rewards2, dones, next_states = [], [], [], [], [], [], []
        s_onehot = torch.FloatTensor(1,27)
        ns_onehot = torch.FloatTensor(1,27)
        counter = 0
        best_count = 1000
        total_count = self.batch_size * self.n_steps
        best_actions = []
        best_states = []
        nn_inp = []
        nn_ninp = []
        inp1_list = []
        inp2_list = []
        op1_list = []
        op2_list = []
        inp_percent = torch.FloatTensor([[1+((self.target1-self.target2)/max(self.target1, self.target2))]])

        
        while counter < total_count:
            done = False
            input1 = 0
            input2 = 0
            output1 = 0
            output2 = 0
            tgt = 0
            cnt = 0
            ind = 0
            st = [] # internal states list
            ac = [] #internal actions list
            self.s_0 = torch.LongTensor([[0]])
            s_1 = torch.LongTensor([[0]])
            
            while done == False:
                s_onehot = torch.FloatTensor(1,27)
                ns_onehot = torch.FloatTensor(1,27)
        
                s_onehot.zero_()
                s_onehot.scatter_(1,self.s_0,1)
                in1_p = output1/self.target1
                in2_p = output2/self.target2
                in_p = torch.FloatTensor([[in1_p,in2_p]])
                s_onehot = torch.cat((s_onehot,in_p),1)
                s_onehot = torch.cat((s_onehot,inp_percent),1)
#                 s_onehot = (s_onehot - torch.mean(s_onehot))/ torch.std(s_onehot)
                
                ind1 = self.network1.get_action(s_onehot.unsqueeze(0))
                ind2 = self.network2.get_action(s_onehot.unsqueeze(0))
                ind = ind1*10+ind2
                ind = torch.from_numpy(np.array(ind,dtype='float32'))
                ac_in = action_index_data[np.where(action_index_data[:,0]==ind),:]
                s_1, i1, i2, o1, o2, r1, r2= next_state(self.s_0[0][0].tolist(),ac_in[0][0][1].tolist(),ac_in[0][0][2].tolist())
                
                cnt+=1
                st.append(self.s_0)
                ac.append(ind)
                
                input1 = input1+i1
                input2 = input2+i2
                output1 = output1+o1
                output2 = output2+o2

                ns_onehot.zero_()
                ns_onehot.scatter_(1,torch.LongTensor([[s_1]]),1)
                nin1_p = output1/self.target1
                nin2_p = output2/self.target2
                nin_p = torch.FloatTensor([[nin1_p,nin2_p]])
                ns_onehot = torch.cat((ns_onehot,nin_p),1)
                ns_onehot = torch.cat((ns_onehot,inp_percent),1)
#                 ns_onehot = (ns_onehot - torch.mean(ns_onehot))/ torch.std(ns_onehot)
                
                if (s_1==torch.FloatTensor([[14]]) or s_1==torch.FloatTensor([[17]])):
                    done = True
                    r1 = r1+torch.FloatTensor([-100])
                    r2 = r2+torch.FloatTensor([-100])
                    
                if (in1_p>=1 and in2_p>=1):
                    r1 = r1+torch.FloatTensor([500])
                    r2 = r2+torch.FloatTensor([500])
                    
                    if (cnt<best_count):
                        best_count = cnt
                    done = True
                    tgt = 1

                if ((in1_p<0.75 and in2_p>=1) or (in2_p<0.75 and in1_p>=1)):
                    done = True
                    r1 = r1+torch.FloatTensor([-30])
                    r2 = r2+torch.FloatTensor([-30])
                    
                if (counter>=1000):
                    done = True
                    r1 = r1+torch.FloatTensor([-100])
                    r2 = r2+torch.FloatTensor([-100])
                    
                self.reward1 +=r1
                self.reward2 +=r2
                states.append(self.s_0)
                next_states.append(s_1)
                actions1.append(ind1)
                actions2.append(ind2)
                rewards1.append(r1)
                rewards2.append(r2)
                dones.append(done)
                tgt_list.append(tgt)
                nn_inp.append(s_onehot)
                nn_ninp.append(ns_onehot)
                self.s_0 = torch.LongTensor([[s_1]])
    
                if done:
                    self.ep_rewards1.append(self.reward1)
                    self.ep_rewards2.append(self.reward2)
                    self.s_0 = torch.LongTensor([[0]])
                    self.reward = 0
                    self.ep_counter += 1
                    inp1_list.extend([input1]); inp2_list.extend([input2]);op1_list.extend([output1]);op2_list.extend([output2]);
                    
                    if self.ep_counter >= self.num_episodes:
                        counter = total_count
                        break
                
                counter += 1
                if counter >= total_count:
                    break
        return states, actions1, actions2, rewards1, rewards2, dones, next_states, inp1_list, inp2_list, op1_list, op2_list,\
        best_count, best_actions, best_states, nn_inp, nn_ninp
    
    def calc_rewards1(self, batch):
        batch1, batch2 = batch
        
        states, actions1, actions2, rewards1, rewards2, dones, next_states = batch1
        nn_inp, nn_ninp = batch2
        
        rewards1 = np.array(rewards1)
        total_steps = len(rewards1)
        target = 20
        
        nn_inp = torch.stack(nn_inp)
        state_values = self.network1.predict(nn_inp)[1]

        nn_ninp = torch.stack(nn_ninp)
        next_state_values = self.network1.predict(nn_ninp)[1]
        
        done_mask = torch.ByteTensor(dones).to(self.network1.device)
        next_state_values[done_mask] = 0.0
        state_values = state_values.detach().numpy().flatten()
        next_state_values = next_state_values.detach().numpy().flatten()
        
        G = np.zeros_like(rewards1, dtype=np.float32)
        td_delta = np.zeros_like(rewards1, dtype=np.float32)
        dones = np.array(dones)
        
        for t in range(total_steps):
            last_step = min(self.n_steps, total_steps - t)
            
            # Look for end of episode
            check_episode_completion = dones[t:t+last_step]
            if check_episode_completion.size > 0:
                if True in check_episode_completion:
                    next_ep_completion = np.where(check_episode_completion == True)[0][0]
                    last_step = next_ep_completion
            
            # Sum and discount rewards
            G[t] = sum([rewards1[t+n:t+n+1] * self.gamma ** n for 
                        n in range(last_step)])
        
        if total_steps > self.n_steps:
            G[:total_steps - self.n_steps] += next_state_values[self.n_steps:] \
                * self.gamma ** self.n_steps
        td_delta = G - state_values
        return G, td_delta
        
    def calc_rewards2(self, batch):
        batch1, batch2 = batch
        
        states, actions1, actions2, rewards1, rewards2, dones, next_states = batch1
        nn_inp, nn_ninp = batch2
        
        rewards2 = np.array(rewards2)
        total_steps = len(rewards2)
        #target = 20
        
        nn_inp = torch.stack(nn_inp)
        state_values = self.network2.predict(nn_inp)[1]

        nn_ninp = torch.stack(nn_ninp)
        next_state_values = self.network2.predict(nn_ninp)[1]
        
        done_mask = torch.ByteTensor(dones).to(self.network2.device)
        next_state_values[done_mask] = 0.0
        state_values = state_values.detach().numpy().flatten()
        next_state_values = next_state_values.detach().numpy().flatten()
        
        G = np.zeros_like(rewards2, dtype=np.float32)
        td_delta = np.zeros_like(rewards2, dtype=np.float32)
        dones = np.array(dones)
        
        for t in range(total_steps):
            last_step = min(self.n_steps, total_steps - t)
            
            # Look for end of episode
            check_episode_completion = dones[t:t+last_step]
            if check_episode_completion.size > 0:
                if True in check_episode_completion:
                    next_ep_completion = np.where(check_episode_completion == True)[0][0]
                    last_step = next_ep_completion
            
            # Sum and discount rewards
            G[t] = sum([rewards2[t+n:t+n+1] * self.gamma ** n for 
                        n in range(last_step)])
        
        if total_steps > self.n_steps:
            G[:total_steps - self.n_steps] += next_state_values[self.n_steps:] \
                * self.gamma ** self.n_steps
        td_delta = G - state_values
        return G, td_delta
        
    def train(self, target1, target2, n_steps=5, batch_size=10, num_episodes=2000, 
              gamma=0.99, beta=1e-3, zeta=0.5):
        self.n_steps = n_steps
        self.gamma = gamma
        self.num_episodes = num_episodes
        self.beta = beta
        self.zeta = zeta
        self.batch_size = batch_size
        self.target1 = target1
        self.target2 = target2
        
        # Set up lists to log data
        self.ep_rewards1 = []
        self.ep_rewards2 = []
        self.kl_div = []
        self.policy_loss_1 = []
        self.value_loss_1 = []
        self.entropy_loss_1 = []
        self.policy_loss_2 = []
        self.value_loss_2 = []
        self.entropy_loss_2 = []
        #self.total_policy_loss1 = []
        #self.total_value_loss1 = []
        #self.total_policy_loss2 = []
        #self.total_value_loss2 = []
        #self.total_loss1 = []
        #self.total_loss2 = []
        
        self.s_0 = torch.LongTensor([[0]])
        self.reward1 = 0
        self.reward2 = 0
        self.ep_counter = 0
        b_count = 1000
        b_actions = []
        b_states = []
        in1 = []
        in2 = []
        out1 = []
        out2 = []
        
        while self.ep_counter < num_episodes:
            
            batch = self.generate_episode()
            nn_inp= batch[14]
            G1, td_delta1 = self.calc_rewards1([batch[0:7], batch[14:]])
            G2, td_delta2 = self.calc_rewards2([batch[0:7], batch[14:]])
            states = batch[0]
            
            if (batch[11]<b_count):
                b_count = batch[11]
                b_actions = batch[12]
                b_states = batch[13]
                
            in1.append(batch[7])
            in2.append(batch[8])
            out1.append(batch[9])
            out2.append(batch[10])
            actions1 = batch[1]
            actions2 = batch[2]
            
            nn_inp = torch.stack(nn_inp)
            self.update1(nn_inp, actions1, G1, td_delta1)
            self.update2(nn_inp, actions2, G2, td_delta2)
            
            #current_probs = self.network.predict(nn_inp)[0].detach().numpy()
            #new_probs = self.network.predict(nn_inp)[0].detach().numpy()
            #kl = -np.sum(current_probs * np.log(new_probs / current_probs))                
            #self.kl_div.append(kl)
            print("\rTarget achieved:{:d} Episode: {:d}, Best Count: {:d} ".format(
                sum(tgt_list), self.ep_counter, b_count), end="")
            
            batch = []
        
        return in1, in2, out1, out2, sum(tgt_list), b_count, b_actions, b_states, self.ep_rewards1, self.ep_rewards2,
            
    def plot_results(self):
        avg_rewards = [np.mean(self.ep_rewards[i:i + self.batch_size]) 
                       if i > self.batch_size 
            else np.mean(self.ep_rewards[:i + 1]) for i in range(len(self.ep_rewards))]

        plt.figure(figsize=(15,10))
        gs = gridspec.GridSpec(3, 2)
        ax0 = plt.subplot(gs[0,:])
        ax0.plot(self.ep_rewards)
        ax0.plot(avg_rewards)
        ax0.set_xlabel('Episode')
        plt.title('Rewards')

        ax1 = plt.subplot(gs[1, 0])
        ax1.plot(self.policy_loss)
        plt.title('Policy Loss')
        plt.xlabel('Update Number')

        ax2 = plt.subplot(gs[1, 1])
        ax2.plot(self.entropy_loss)
        plt.title('Entropy Loss')
        plt.xlabel('Update Number')

        ax3 = plt.subplot(gs[2, 0])
        ax3.plot(self.value_loss)
        plt.title('Value Loss')
        plt.xlabel('Update Number')

        ax4 = plt.subplot(gs[2, 1])
        ax4.plot(self.kl_div)
        plt.title('KL Divergence')
        plt.xlabel('Update Number')

        plt.tight_layout()
        plt.show()
        
    def calc_loss1(self, nn_inp, actions, rewards, advantages, beta=0.001):
        actions_t = torch.LongTensor(actions).to(self.network1.device)
        rewards_t = torch.FloatTensor(rewards).to(self.network1.device)
        advantages_t = torch.FloatTensor(advantages).to(self.network1.device)
        
        log_probs = self.network1.get_log_probs(nn_inp)
        log_probs =  log_probs.view(-1,10)
        log_prob_actions = advantages_t * log_probs[range(len(actions)), actions]
        policy_loss = -log_prob_actions.mean()
        
        action_probs, values = self.network1.predict(nn_inp)
        action_probs = action_probs.view(1,-1,10)
        values = values.view(1,-1,1)
        
        values = values[0]
        
        entropy_loss = -self.beta * (action_probs * log_probs).sum(dim=1).mean()
        value_loss = self.zeta * nn.MSELoss()(values.squeeze(-1), rewards_t)
        
        # Append values
        self.policy_loss_1.append(policy_loss.item())
        self.value_loss_1.append(value_loss.item())
        self.entropy_loss_1.append(entropy_loss.item())
        
        return policy_loss, entropy_loss, value_loss
        
    def calc_loss2(self, nn_inp, actions, rewards, advantages, beta=0.001):
        actions_t = torch.LongTensor(actions).to(self.network2.device)
        rewards_t = torch.FloatTensor(rewards).to(self.network2.device)
        advantages_t = torch.FloatTensor(advantages).to(self.network2.device)
        
        log_probs = self.network2.get_log_probs(nn_inp)
        log_probs =  log_probs.view(-1,10)
        log_prob_actions = advantages_t * log_probs[range(len(actions)), actions]
        policy_loss = -log_prob_actions.mean()
        
        action_probs, values = self.network2.predict(nn_inp)
        action_probs = action_probs.view(1,-1,10)
        values = values.view(1,-1,1)
        
        values = values[0]
        
        entropy_loss = -self.beta * (action_probs * log_probs).sum(dim=1).mean()
        value_loss = self.zeta * nn.MSELoss()(values.squeeze(-1), rewards_t)
        
        # Append values
        self.policy_loss_2.append(policy_loss.item())
        self.value_loss_2.append(value_loss.item())
        self.entropy_loss_2.append(entropy_loss.item())
        
        return policy_loss, entropy_loss, value_loss
        
    def update1(self, nn_inp, actions, rewards, advantages):
        self.network1.optimizer.zero_grad()
        policy_loss, entropy_loss, value_loss = self.calc_loss1(nn_inp, 
            actions, rewards, advantages)
        
        total_policy_loss = policy_loss - entropy_loss
        #self.total_policy_loss.append(total_policy_loss.item())
        total_policy_loss.backward(retain_graph=True)
        
        value_loss.backward()
        
        #total_loss = policy_loss + value_loss + entropy_loss
        #self.total_value_loss1.append(value_loss.item())
        self.network1.optimizer.step()
    
    def update2(self, nn_inp, actions, rewards, advantages):
        self.network2.optimizer.zero_grad()
        policy_loss, entropy_loss, value_loss = self.calc_loss2(nn_inp, 
            actions, rewards, advantages)
        
        total_policy_loss = policy_loss - entropy_loss
        #self.total_policy_loss2.append(total_policy_loss.item())
        total_policy_loss.backward(retain_graph=True)
        
        value_loss.backward()
        
        #total_loss = policy_loss + value_loss + entropy_loss
        #self.total_value_loss2.append(value_loss.item())
        self.network2.optimizer.step()
    
input1 = 0
input2 = 0
output1 = 0
output2 = 0
tgt = 0
tgt_list = [0]
checker = 0

ipg1 = []
ipg2 = []
opg1 = []
opg2 = []
t_l = []
bcount_list = []
int_bct = 1000
total_targets_achieved = []

total_policy_loss1_list = []
total_policy_loss2_list = []

total_value_loss1_list = []
total_value_loss2_list = []

total_rewards1_list = []
total_rewards2_list = []

LR = 1e-3

net1 = actorCriticNet(learning_rate=LR, n_hidden_layers=2, n_hidden_nodes=9)
net2 = actorCriticNet(learning_rate=LR, n_hidden_layers=2, n_hidden_nodes=9)
#net1 = torch.load('ma1_a2c_robot.pt')
#net2 = torch.load('ma2_a2c_robot.pt')
target1_list = [20,10,20]
target2_list = [20,20,10]
SLR = [1e-3, 1e-3, 1e-3]

a2c = A2C(net1,net2)
for j in range(len(target1_list)):
    
    a2c.network1.policy.apply(weights_init)
    a2c.network1.value.apply(weights_init)
    a2c.network2.policy.apply(weights_init)
    a2c.network2.value.apply(weights_init)
    
    for i in range(4):
        if i == 0:
            # Epoch 0: body, policy and value heads all train at the scheduled learning rate
            a2c.network1.optimizer = Adam([
                    {"params": a2c.network1.body.parameters(), "lr": SLR[j]},
                    {"params": a2c.network1.policy.parameters(), "lr": SLR[j]},
                    {"params": a2c.network1.value.parameters(), "lr": SLR[j]},
                ], lr=SLR[j])

            a2c.network2.optimizer = Adam([
                    {"params": a2c.network2.body.parameters(), "lr": SLR[j]},
                    {"params": a2c.network2.policy.parameters(), "lr": SLR[j]},
                    {"params": a2c.network2.value.parameters(), "lr": SLR[j]},
                ], lr=SLR[j])
        else:
            # Later epochs: the body's learning rate is set to zero so only the heads keep updating.
            # Note that building a new Adam here also starts from fresh moment estimates.
            a2c.network1.optimizer = Adam([
                    {"params": a2c.network1.body.parameters(), "lr": 0},
                    {"params": a2c.network1.policy.parameters(), "lr": SLR[j]},
                    {"params": a2c.network1.value.parameters(), "lr": SLR[j]},
                ], lr=SLR[j])

            a2c.network2.optimizer = Adam([
                    {"params": a2c.network2.body.parameters(), "lr": 0},
                    {"params": a2c.network2.policy.parameters(), "lr": SLR[j]},
                    {"params": a2c.network2.value.parameters(), "lr": SLR[j]},
                ], lr=SLR[j])
        
        print('\nEpoch\t',i)
        in1, in2, ot1, ot2, t_o, best_count, best_action, best_state, total_rewards1, total_rewards2 = a2c.train(target1=target1_list[j], target2=target2_list[j], n_steps=12, num_episodes=1000, beta=1e-3, zeta=1e-4)