Hi,
I was trying out some code related to reinforcement learning (TD3) when I thought that I could make the code a little bit nicer by replacing a few lines. However, this broke the entire autograd update and I am just wondering why. Below is the actor class with the forward pass that breaks the code.
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, action_std):
super(Actor, self).__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, action_dim),
nn.Tanh())
self.action_std = torch.full((action_dim,), action_std).to(device)
def forward(self, state):
action_mean = self.net(state)
return action_mean + torch.randn_like(action_mean) * self.action_std
works fine, but
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, action_std):
super(Actor, self).__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, action_dim),
nn.Tanh())
self.action_var = torch.full((action_dim,), action_std*action_std).to(device)
self.cov_mat = torch.diag(self.action_var).to(device)
def forward(self, state):
action_mean = self.net(state)
dist = torch.distributions.MultivariateNormal(action_mean, self.cov_mat)
return dist.sample()
fails and I have no clue why.