with the training process , the actor_loss and critic_loss increase rather than decrease.
and here is my actor net and critic net with hidden size1=512,hidden size2=1024,hidden size3=512,hidden size4=256:
class ActorNet(nn.Module):
def __init__(self):
super(ActorNet, self).__init__()
self.input_size = 10 + K * 2
self.output_size = 1 + 1
self.fc1 = nn.Linear(self.input_size, HIDDEN_SIZE_1)
self.fc2 = nn.Linear(HIDDEN_SIZE_1, HIDDEN_SIZE_2)
self.fc3 = nn.Linear(HIDDEN_SIZE_2, HIDDEN_SIZE_3)
self.fc4 = nn.Linear(HIDDEN_SIZE_3, HIDDEN_SIZE_4)
self.fc5 = nn.Linear(HIDDEN_SIZE_4, self.output_size)
# init weight
nn.init.xavier_normal_(self.fc1.weight)
nn.init.constant_(self.fc1.bias, 0)
nn.init.xavier_normal_(self.fc2.weight)
nn.init.constant_(self.fc2.bias, 0)
nn.init.xavier_normal_(self.fc3.weight)
nn.init.constant_(self.fc3.bias, 0)
nn.init.xavier_normal_(self.fc4.weight)
nn.init.constant_(self.fc4.bias, 0)
nn.init.xavier_normal_(self.fc5.weight)
nn.init.constant_(self.fc5.bias, 0)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
x = torch.relu(x)
x = self.fc4(x)
x = torch.relu(x)
x = self.fc5(x)
x[0][0] = torch.sigmoid(x[0][0])
x[0][1] = torch.tanh(x[0][1])
return x
class CriticNet(nn.Module):
def __init__(self):
super(CriticNet, self).__init__()
self.input_size = 10 + K * 2 + 2
self.output_size = 1
self.fc1 = nn.Linear(self.input_size, HIDDEN_SIZE_1)
self.fc2 = nn.Linear(HIDDEN_SIZE_1, HIDDEN_SIZE_2)
self.fc3 = nn.Linear(HIDDEN_SIZE_2, HIDDEN_SIZE_3)
self.fc4 = nn.Linear(HIDDEN_SIZE_3, HIDDEN_SIZE_4)
self.fc5 = nn.Linear(HIDDEN_SIZE_4, self.output_size)
# init weight
nn.init.xavier_normal_(self.fc1.weight)
nn.init.constant_(self.fc1.bias, 0)
nn.init.xavier_normal_(self.fc2.weight)
nn.init.constant_(self.fc2.bias, 0)
nn.init.xavier_normal_(self.fc3.weight)
nn.init.constant_(self.fc3.bias, 0)
nn.init.xavier_normal_(self.fc4.weight)
nn.init.constant_(self.fc4.bias, 0)
nn.init.xavier_normal_(self.fc5.weight)
nn.init.constant_(self.fc5.bias, 0)
def forward(self, state,action):
x = torch.cat([state, action], 1)
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
x = torch.relu(x)
x = self.fc4(x)
x = torch.relu(x)
x = self.fc5(x)
return x
and the actor loss and the critic loss is defined by :
policy_Q = self.critic(state_batch, self.actor(state_batch))
actor_loss = -policy_Q.mean()
next_action_batch = self.target_actor(next_state_batch)
target_Q = self.target_critic(next_state_batch,next_action_batch.detach())
label_Q = reward_batch + GAMMA * target_Q
policy_Q_ = self.critic(state_batch, action_batch)
#critic_loss = ((label_Q - policy_Q_) ** 2).mean()
critic_loss = self.value_criterion(label_Q, policy_Q_.detach())
where the value_criterion is equal to mse loss