Autograd through torch.distributions.MultivariateNormal not working

Hi,

I was trying out some code related to reinforcement learning (TD3) and thought I could make it a little nicer by replacing a few lines. However, this broke the autograd update entirely, and I am wondering why. Below are two versions of the actor class; the forward pass of the second one breaks training.

import torch
import torch.nn as nn

# assuming the usual device setup from the rest of the script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(Actor, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh())

        # fixed per-dimension standard deviation for the exploration noise
        self.action_std = torch.full((action_dim,), action_std).to(device)

    def forward(self, state):
        action_mean = self.net(state)
        # additive Gaussian noise; gradients still flow through action_mean
        return action_mean + torch.randn_like(action_mean) * self.action_std

works fine, but

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(Actor, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh())

        # diagonal covariance matrix built from the fixed action variance
        self.action_var = torch.full((action_dim,), action_std * action_std).to(device)
        self.cov_mat = torch.diag(self.action_var)

    def forward(self, state):
        action_mean = self.net(state)
        dist = torch.distributions.MultivariateNormal(action_mean, self.cov_mat)
        # draw a stochastic action from the distribution
        return dist.sample()

fails, and I have no clue why.

Replace sample() with rsample(). sample() is executed without gradient tracking, so the returned tensor is detached from the computation graph and nothing flows back into action_mean or the network weights. rsample() uses the reparameterization trick (the documentation calls it reparameterized sampling): the sample is expressed as a deterministic function of the distribution's parameters plus independent noise, so autograd can differentiate through it.
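
A minimal sketch of the difference, using an arbitrary 3-dimensional standard normal just for illustration:

import torch

mean = torch.zeros(3, requires_grad=True)
cov = torch.eye(3)
dist = torch.distributions.MultivariateNormal(mean, cov)

a = dist.sample()   # runs without grad tracking: a.grad_fn is None, graph is cut
b = dist.rsample()  # reparameterized as mean + scale @ noise: keeps the graph

print(a.requires_grad)  # False
print(b.requires_grad)  # True

b.sum().backward()      # gradients flow back through the sampling step
print(mean.grad)        # tensor([1., 1., 1.])

This is exactly the difference between your two Actor versions: adding scaled torch.randn_like noise to action_mean is already a hand-written reparameterized sample, whereas dist.sample() cuts the graph.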


Thanks a lot, that fixed it!