Autograd for variance in normal distribution not working

Hi,

I am currently fiddling around with a PPO implementation where the loss depends on the entropy of a multivariate normal distribution. The entropy is calculated as follows:

import torch
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, config):
        super(Actor, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Tanh())

        # action_std (initial std value) and device are defined elsewhere in the script
        self.action_std = torch.nn.Parameter(torch.ones(action_dim)*action_std).to(device)

    def evaluate(self, state, action):
        action_mean = self.net(state)
        # independent Normal per action dimension; sum over dims for the joint log-prob/entropy
        dist = torch.distributions.normal.Normal(action_mean, self.action_std)
        action_logprobs = torch.sum(dist.log_prob(action), dim=1)
        dist_entropy = torch.sum(dist.entropy(), dim=1)
        return action_logprobs, dist_entropy

and the loss also depends on the return value “dist_entropy”. However, during training the parameter “self.action_std” does not change at all. Any clue what the reason might be?
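For anyone debugging the same symptom, a minimal check (just a sketch; it assumes the Actor above, hypothetical sizes, and that device is actually a CUDA device, since .to() is a no-op when the tensor is already on the target device):

action_std = 0.5     # the globals the snippet above refers to (values assumed here)
device = 'cuda'

actor = Actor(state_dim=4, action_dim=2, config=None).to(device)

# 1) action_std was replaced by a plain tensor, so it is not registered and an
#    optimizer built from actor.parameters() never sees it:
print([name for name, _ in actor.named_parameters()])   # only the net.* weights

# 2) the entropy term does produce a gradient, but it ends up on the hidden CPU
#    leaf created inside __init__, not on actor.action_std (a non-leaf copy):
state = torch.randn(8, 4, device=device)
action = torch.randn(8, 2, device=device)
_, dist_entropy = actor.evaluate(state, action)
(-dist_entropy.mean()).backward()
print(actor.action_std.grad)   # None (PyTorch warns that .grad of a non-leaf tensor is not populated)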

Hi,

The problem is that self.action_std is not the Parameter itself: the .to() op creates a new Tensor, so the Parameter you created is never registered on the module.
You should do:

        self.action_std = torch.nn.Parameter(torch.ones(action_dim, device=device)*action_std)
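As an aside, a sketch of another pattern that also works (hypothetical TinyActor): create the Parameter on the CPU and move the whole module afterwards; nn.Module.to() converts the registered Parameters themselves, so the attribute stays a Parameter:

import torch
import torch.nn as nn

class TinyActor(nn.Module):                 # hypothetical minimal module
    def __init__(self, action_dim, action_std=0.5):
        super().__init__()
        self.action_std = nn.Parameter(torch.ones(action_dim) * action_std)

actor = TinyActor(action_dim=2).to('cuda')  # Module.to() moves the registered Parameter as well
print(type(actor.action_std))               # <class 'torch.nn.parameter.Parameter'>
print(actor.action_std.is_leaf)             # True -- it will receive gradients and be optimized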

Hi,

Thanks for the quick reply. Indeed, this modification solved the issue, but it leaves me a little puzzled. Out of curiosity I tried the following:

a = torch.nn.Parameter(torch.ones([1], device='cpu'))
b = torch.nn.Parameter(torch.ones([1], device='cuda'))
c = torch.nn.Parameter(torch.ones([1])).to('cpu')
d = torch.nn.Parameter(torch.ones([1])).to('cuda')

print(a, end="\n\n")
print(b, end="\n\n")
print(c, end="\n\n")
print(d, end="\n\n")

which gave the following output:

Parameter containing:
tensor([1.], requires_grad=True)

Parameter containing:
tensor([1.], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([1.], requires_grad=True)

tensor([1.], device='cuda:0', grad_fn=<CopyBackwards>)

a, b, and c make sense, but what happens to d? From my basic PyTorch understanding (and the help of Stack Overflow), a and b are straightforward. c would be copied to the CPU, but since it is already stored there, the .to('cpu') has no effect.

Now to d: my understanding until now was that torch.ones([1]) is a “normal” PyTorch tensor that becomes a Parameter (and thus trainable) and is then moved to the GPU by the .to('cuda') in the last step, so that d finally points to the copy on the GPU. But the print output indicates that something different is going on… By the way, what does grad_fn=<CopyBackwards> mean?
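A quick way to see the difference (a sketch, reusing the tensors from the snippet above):

print(a.is_leaf, b.is_leaf, c.is_leaf, d.is_leaf)   # True True True False
print(d.grad_fn)   # e.g. <CopyBackwards object at 0x...>, the backward node recorded for the CPU->GPU copy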

a, b, and c make sense, but what happens to d?

Actually, the special case here is c.
a and b are Parameters because you created them directly on the target device.
d is just a Tensor with a grad_fn: it is not a leaf anymore, because it was created by a differentiable op (the copy to the GPU).
c is the special case: since the Tensor is already on the CPU, .to('cpu') does nothing and you get the original Parameter back.
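To make that concrete, a minimal sketch (hypothetical p, mirroring case d): the gradient flows back through the copy to the original leaf, and d itself never accumulates anything:

p = torch.nn.Parameter(torch.ones([1]))   # the original CPU leaf
d = p.to('cuda')                          # non-leaf copy, grad_fn=<CopyBackwards>

(d * 3).sum().backward()
print(p.grad)    # tensor([3.]) -- the gradient flows back through the copy to the leaf
print(d.grad)    # None -- non-leaf tensors do not retain .grad by default

torch.optim.SGD([d], lr=0.1)   # raises ValueError: can't optimize a non-leaf Tensor

In the Actor from the first post the unregistered tensor never even reaches the optimizer (it is not in actor.parameters()), so there is no error, just no update.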
