Let’s look at this small piece of code:
```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class Proba(nn.Module):
    def __init__(self, size):
        super(Proba, self).__init__()
        self.mu = nn.Linear(size, size)

    def forward(self, obs):
        mu = self.mu(obs)
        dist = Normal(mu, 0.2)
        log_prob = dist.log_prob(obs)
        return log_prob.mean()


proba = Proba(10)
optimizer = torch.optim.Adam(proba.parameters(), lr=1e-3)

# fixed observation to reconstruct:
obs = torch.rand(10)

# want to optimize the probability to reconstruct obs:
if __name__ == "__main__":
    for i in range(500):
        loss = -proba(obs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            # should never be larger than 0:
            print("log_prob of obs =", -loss.item())
```
The goal of this code is to reconstruct a fixed signal `obs` by training a linear model to predict the mean of an element-wise normal distribution with a fixed small standard deviation (0.2), deliberately overfitting to that single observation.
However, after ~500 iterations, the printed log_prob starts being positive, while log-probabilities should never be positive.
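To see the same behavior in isolation, I evaluated `log_prob` at the mean of a `Normal` with the same standard deviation (a minimal check I wrote, separate from the training code above; the `Normal(0, 0.2)` setup is mine):

```python
import math

import torch
from torch.distributions import Normal

# log_prob of a Normal evaluated at its own mean is -log(sigma * sqrt(2 * pi)).
# With sigma = 0.2 the density peaks at ~1.995, so its log comes out positive.
dist = Normal(torch.zeros(1), 0.2)
print(dist.log_prob(torch.zeros(1)))            # tensor([0.6905])
print(-math.log(0.2 * math.sqrt(2 * math.pi)))  # 0.6905...
```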
Is something wrong in the code, or in my use of `torch.distributions.Normal` and `log_prob`?