# Understanding log_prob for Normal distribution in PyTorch

I’m currently trying to solve Pendulum-v0 from the OpenAI Gym environments, which has a continuous action space. As a result, I need to use a Normal distribution to sample my actions. What I don’t understand is the shape of the log_prob when using it:

``````python
import torch
from torch.distributions import Normal

means = torch.tensor([[0.0538],
[0.0651]])
stds = torch.tensor([[0.7865],
[0.7792]])

dist = Normal(means, stds)
a = torch.tensor([1.2,3.4])
d = dist.log_prob(a)
print(d.size())
``````

I was expecting a tensor of size 2 (one log_prob for each action), but it outputs a tensor of size (2, 2).

However, when using a Categorical distribution for a discrete environment, the log_prob has the expected size:

``````python
import torch
from torch.distributions import Categorical

logits = torch.tensor([[-0.0657, -0.0949],
[-0.0586, -0.1007]])

dist = Categorical(logits = logits)
a = torch.tensor([1, 1])
print(dist.log_prob(a).size())
``````

This gives me a tensor of size 2.

Why is the log_prob for the Normal distribution a different size?


Hello sabeaussan!

Because `means` and `stds` both have shape `(2, 1)`, you have created `dist` as a “batch” of two Gaussian distributions.
(The first has `mean = 0.0538` and `std = 0.7865`, and the second
has `mean = 0.0651` and `std = 0.7792`.)
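One way to see this directly is to ask each distribution for its `batch_shape` and `event_shape`. The sketch below is a minimal illustration assuming a recent PyTorch (it uses `torch.tensor` rather than the old `FloatTensor`). It reuses the numbers from the question and contrasts `Normal` with `Categorical`, whose last `logits` dimension indexes the categories rather than adding batch entries:

```python
import torch
from torch.distributions import Categorical, Normal

# Same parameters as in the question: shape (2, 1), a column of two values.
means = torch.tensor([[0.0538], [0.0651]])
stds = torch.tensor([[0.7865], [0.7792]])
normal = Normal(means, stds)

# Normal is a scalar distribution: every element of `means` is its own
# Gaussian, so the batch shape is the full parameter shape.
print(normal.batch_shape)  # torch.Size([2, 1])
print(normal.event_shape)  # torch.Size([])

# For Categorical, the last dimension of `logits` indexes the categories,
# so a (2, 2) logits tensor describes a batch of only two distributions.
logits = torch.tensor([[-0.0657, -0.0949], [-0.0586, -0.1007]])
categorical = Categorical(logits=logits)
print(categorical.batch_shape)  # torch.Size([2])
print(categorical.event_shape)  # torch.Size([])

# An action tensor of shape (2,) therefore lines up exactly with the
# Categorical batch, giving one log_prob per action.
print(categorical.log_prob(torch.tensor([1, 1])).shape)  # torch.Size([2])
```

That is why the Categorical case "just works": its batch shape is `(2,)`, while the Normal batch shape is `(2, 1)`.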

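When you call `log_prob`, the tensor `a` (shape `(2,)`) is broadcast against
the batch shape `(2, 1)`. You therefore get two `log_prob`s (one for each
Gaussian in your “batch”) for each of the two elements of `a`, for a total of
four `log_prob`s: entry `[i, j]` is the log-density of `a[j]` under Gaussian `i`.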
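To check the broadcasting claim numerically, here is a small sketch (again assuming a recent PyTorch) that recomputes the `(2, 2)` result by hand from the Gaussian log-density, log N(x; mu, sigma) = -(x - mu)^2 / (2 sigma^2) - log(sigma) - log(2 pi) / 2, using ordinary tensor broadcasting:

```python
import math

import torch
from torch.distributions import Normal

means = torch.tensor([[0.0538], [0.0651]])  # shape (2, 1)
stds = torch.tensor([[0.7865], [0.7792]])   # shape (2, 1)
a = torch.tensor([1.2, 3.4])                # shape (2,)

dist = Normal(means, stds)
d = dist.log_prob(a)

# a (2,) is broadcast against the batch shape (2, 1): it behaves like a
# (1, 2) row, so the result has shape (2, 2).
print(d.shape)  # torch.Size([2, 2])

# The same computation by hand: (a - means) is (2, 2) with entry
# [i, j] = a[j] - means[i], i.e. value j scored by Gaussian i.
manual = (
    -((a - means) ** 2) / (2 * stds ** 2)
    - torch.log(stds)
    - 0.5 * math.log(2 * math.pi)
)
print(torch.allclose(d, manual))  # True
```

The first row, roughly `[-1.7407, -9.7294]`, is both actions scored by the first Gaussian, matching the 2x2 output in the session below.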

This is illustrated by some additions I made to the code you posted:

``````python
import torch
from torch.distributions import Normal
torch.__version__

torch.manual_seed (2020)

means = torch.FloatTensor([[0.0538],
[0.0651]])
stds = torch.FloatTensor([[0.7865],
[0.7792]])

dist = Normal(means, stds)
a = torch.FloatTensor([1.2,3.4])
d = dist.log_prob(a)
print(d.size())

disB = Normal (torch.FloatTensor ([0.0]), torch.FloatTensor ([1.0]))

dist.log_prob (a)
disB.log_prob (a)

dist.sample()
disB.sample()
dist.sample_n (3)
disB.sample_n (3)

b = torch.FloatTensor([5.6])
dist.log_prob (b)
disB.log_prob (b)
``````

And here are the results:

``````
>>> import torch
>>> from torch.distributions import Normal
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x000002A2CFE06630>
>>>
>>> means = torch.FloatTensor([[0.0538],
...         [0.0651]])
>>> stds = torch.FloatTensor([[0.7865],
...         [0.7792]])
>>>
>>> dist = Normal(means, stds)
>>> a = torch.FloatTensor([1.2,3.4])
>>> d = dist.log_prob(a)
>>> print(d.size())
torch.Size([2, 2])
>>>
>>> disB = Normal (torch.FloatTensor ([0.0]), torch.FloatTensor ([1.0]))
>>>
>>> dist.log_prob (a)

-1.7407 -9.7294
-1.7301 -9.8282
[torch.FloatTensor of size 2x2]

>>> disB.log_prob (a)

-1.6389
-6.6989
[torch.FloatTensor of size 2]

>>>
>>> dist.sample()

1.0269
-0.6832
[torch.FloatTensor of size 2x1]

>>> disB.sample()

1.5415
[torch.FloatTensor of size 1]

>>> dist.sample_n (3)

(0 ,.,.) =
-0.2670
0.7512

(1 ,.,.) =
0.0954
0.1236

(2 ,.,.) =
0.4295
-0.4615
[torch.FloatTensor of size 3x2x1]

>>> disB.sample_n (3)

-2.1489
-1.1463
-0.2720
[torch.FloatTensor of size 3x1]

>>>
>>> b = torch.FloatTensor([5.6])
>>> dist.log_prob (b)

-25.5424
-25.8980
[torch.FloatTensor of size 2x1]

>>> disB.log_prob (b)

-16.5989
[torch.FloatTensor of size 1]
``````
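If what you want is one `log_prob` per action, the shapes just have to line up. Below is a minimal sketch of two ways to do that. It assumes each row of `means`/`stds` corresponds to one state and the trailing dimension is the (1-D) Pendulum action. Summing over the action dimension, as in option 1, is the usual way to get a joint log-prob when the action has several independent components. It also uses `sample((n,))`, the current spelling of what `sample_n(n)` does in the 0.3 session above:

```python
import torch
from torch.distributions import Normal

means = torch.tensor([[0.0538], [0.0651]])  # (batch=2, action_dim=1)
stds = torch.tensor([[0.7865], [0.7792]])
dist = Normal(means, stds)

# Option 1: give the actions the same shape as the batch, (2, 1), so no
# cross-broadcasting happens; each row's action is scored by its own Gaussian.
a = torch.tensor([[1.2], [3.4]])
log_p = dist.log_prob(a)          # shape (2, 1)
log_p_per_action = log_p.sum(-1)  # shape (2,): summed over the action dimension

# Option 2: drop the trailing size-1 dimension from the parameters instead.
dist_flat = Normal(means.squeeze(-1), stds.squeeze(-1))    # batch_shape (2,)
log_p_flat = dist_flat.log_prob(torch.tensor([1.2, 3.4]))  # shape (2,)

print(log_p_per_action.shape, log_p_flat.shape)  # torch.Size([2]) torch.Size([2])
print(torch.allclose(log_p_per_action, log_p_flat))  # True

# Sampling follows the same shapes; sample((n,)) prepends a sample dimension.
print(dist.sample().shape)      # torch.Size([2, 1])
print(dist.sample((3,)).shape)  # torch.Size([3, 2, 1])
```

Both options give the two "diagonal" values (about `-1.7407` and `-9.8282`) instead of the full 2x2 grid.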

Best.

K. Frank
