How to avoid vanishing gradients in pathwise derivative policy gradient

I want to train a pathwise derivative policy, but the output of my NN becomes NaN after about 5000 training iterations. I guess this is caused by vanishing gradients. How can I solve this problem?

My NN is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ACTOR_QVALUE(nn.Module):
    def __init__(self, input_size, hidden_size, action_size):
        super(ACTOR_QVALUE, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=2,
                            bias=True, batch_first=True, dropout=0, bidirectional=False)
        # Actor head: LSTM summary + action_state -> action probabilities
        self.hidden_mu1 = nn.Linear(hidden_size + action_size, 64)
        self.hidden_mu2 = nn.Linear(64, 32)
        self.hidden_mu3 = nn.Linear(32, 16)
        self.hidden_mu4 = nn.Linear(16, action_size)

        # Critic head: shared 64-d features + action -> Q-value
        self.hidden_layer1 = nn.Linear(64 + action_size, 32)
        self.hidden_layer2 = nn.Linear(32, 8)
        self.hidden_layer3 = nn.Linear(8, 1)

    def forward(self, env_state, action_state, action=None, type='actor'):
        lstm_out, (h_n, c_n) = self.lstm(env_state)
        # Output of the last time step, concatenated with action_state
        cat_layer = torch.cat((lstm_out[:, -1, :], action_state), 1)
        mu = F.relu(self.hidden_mu1(cat_layer))
        if type == 'actor':
            mu = F.leaky_relu(self.hidden_mu2(mu))
            mu = F.leaky_relu(self.hidden_mu3(mu))
            mu = torch.softmax(self.hidden_mu4(mu), dim=1)
            return mu
        else:
            # Critic: shared features + the given action -> Q-value
            cat_layer = torch.cat((mu, action), 1)
            q = F.leaky_relu(self.hidden_layer1(cat_layer))
            q = F.leaky_relu(self.hidden_layer2(q))
            q = self.hidden_layer3(q)
            return q
```

where env_state (shape 250×120) and action_state (shape 1×8) hold the state information, and action is the action, with shape 1×8. The NN outputs the action if type='actor'; otherwise it outputs the Q-value.
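To check how the shapes line up in the `torch.cat` call, here is a small standalone sketch of just the LSTM part. It assumes env_state is batched as (batch=1, 250 time steps, 120 features) and uses a hypothetical hidden_size of 64; neither value comes from the post. With `batch_first=True`, `lstm_out[:, -1, :]` is the last time step's output, which for a unidirectional LSTM equals the top layer's final hidden state `h_n[-1]`.

```python
import torch
import torch.nn as nn

# Assumed sizes: batch of 1, 250 time steps, 120 features per step,
# hidden_size=64 (hypothetical), action_size=8.
batch, seq_len, input_size, hidden_size, action_size = 1, 250, 120, 64, 8

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=2, batch_first=True)

env_state = torch.randn(batch, seq_len, input_size)
action_state = torch.randn(batch, action_size)

# With batch_first=True, lstm_out has shape (batch, seq_len, hidden_size).
lstm_out, (h_n, c_n) = lstm(env_state)

# Output of the last time step only: (batch, hidden_size).
last_step = lstm_out[:, -1, :]

# Concatenate along the feature dimension: (batch, hidden_size + action_size).
cat_layer = torch.cat((last_step, action_state), dim=1)

print(lstm_out.shape)   # torch.Size([1, 250, 64])
print(last_step.shape)  # torch.Size([1, 64])
print(cat_layer.shape)  # torch.Size([1, 72])
```

This is why `hidden_mu1` takes `hidden_size + action_size` inputs.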

NaN is usually not caused by vanishing gradients. More often, it is caused by repeatedly updating part of your model with extremely large gradients, for example from something being divided by zero. Once a NaN appears, it spreads to all parameters.
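A minimal sketch of how a single division by zero poisons a parameter. The scenario (normalizing by a standard deviation that happens to be zero) and the `1e-8` epsilon are illustrative assumptions, not taken from the model above. Adding an epsilon inside the square root is just one common guard against this.

```python
import torch

# A hypothetical parameter whose entries are all equal, so the
# standard deviation computed from it is exactly zero.
w = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([w], lr=0.1)

y = w * 2.0
z = (y - y.mean()) / y.std()  # 0 / 0 -> NaN
loss = z.sum()
loss.backward()

print("loss:", loss.item())                                  # nan
print("grad finite:", torch.isfinite(w.grad).all().item())   # False
opt.step()
print("param after step:", w.data)  # the non-finite gradient has overwritten w

# One common guard: keep the denominator away from zero.
w2 = torch.nn.Parameter(torch.ones(3))
y2 = w2 * 2.0
z2 = (y2 - y2.mean()) / torch.sqrt(y2.var() + 1e-8)
loss2 = z2.sum()
loss2.backward()
print("guarded grad finite:", torch.isfinite(w2.grad).all().item())  # True
```

After one bad step the parameter is no longer finite, so every later forward pass through it also produces NaN.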

Try clipping your gradients with torch.nn.utils.clip_grad_norm_, or debug your model's gradients with module.register_backward_hook.
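Both suggestions can be combined in a pathwise derivative actor update. The sketch below is a guess at the shape of such a training step, not the poster's code. The small `actor` and `critic` networks, their sizes, the learning rate, and `max_norm=1.0` are all hypothetical. It uses `register_full_backward_hook`, because `register_backward_hook` is deprecated in recent PyTorch versions. It also skips the optimizer step if any gradient is still non-finite.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the actor and critic; sizes are illustrative.
state_size, action_size = 16, 8
actor = nn.Sequential(nn.Linear(state_size, 32), nn.LeakyReLU(),
                      nn.Linear(32, action_size), nn.Softmax(dim=1))
critic = nn.Sequential(nn.Linear(state_size + action_size, 32), nn.LeakyReLU(),
                       nn.Linear(32, 1))
actor_opt = torch.optim.Adam(actor.parameters(), lr=1e-3)

# Record the norm of the gradient flowing into each Linear layer's output.
grad_norms = {}

def make_hook(name):
    def hook(module, grad_input, grad_output):
        # grad_output[0] is the gradient w.r.t. the module's output.
        grad_norms[name] = grad_output[0].norm().item()
    return hook

for name, module in actor.named_modules():
    if isinstance(module, nn.Linear):
        module.register_full_backward_hook(make_hook(name))

states = torch.randn(4, state_size)

# Pathwise derivative actor update: maximize Q(s, mu(s)) by
# backpropagating through the critic into the actor.
actions = actor(states)
q = critic(torch.cat((states, actions), dim=1))
actor_loss = -q.mean()

actor_opt.zero_grad()
actor_loss.backward()  # also fills critic grads; zero them before the critic update

# Clip the actor's total gradient norm; the return value is the norm
# measured before clipping, which is worth logging.
total_norm = torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=1.0)

# Skip the update entirely if any gradient is non-finite.
if all(torch.isfinite(p.grad).all() for p in actor.parameters() if p.grad is not None):
    actor_opt.step()

print("per-layer output grad norms:", grad_norms)
print("total grad norm before clipping:", total_norm.item())
```

Logging `total_norm` and the per-layer norms over training shows which layer's gradients blow up before the NaN appears.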

Thank you for your reply.

Thank you for your advice.