Dirichlet log-likelihood flips sign as the likelihood crosses 1

Hi, I am training an LSTM + multi-head attention (MHA) network in which I compute the negative log-likelihood of a distribution whose parameters are produced by the output layer of the model.

I found that across time batches the gradient keeps jumping from positive to negative and back. My interpretation is that a negative loss here means the current parameters are confident about (assign high likelihood to) the observed data point.

That is, the Dirichlet is so concentrated around the region where the data point was observed that its PDF there is higher than 1. The likelihood is then greater than 1, so the negative log-likelihood becomes negative.

def get_loss(output, target, epsilon):
    """Return the Dirichlet negative log-likelihood of ``target``.

    ``output`` is used directly as the Dirichlet concentration and
    ``epsilon`` is forwarded unchanged to ``Dirichlet.log_prob``. The
    resulting log densities are summed over their last dimension and the
    negated total is returned.

    Note: a Dirichlet log density is not bounded above by 0 (the density can
    exceed 1 when the distribution is sharply concentrated), so the value
    returned here can legitimately be negative.
    """
    log_density = Dirichlet(output).log_prob(target, epsilon)
    return -log_density.sum(-1)

class Dirichlet(ExponentialFamily):
    r"""
    Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.

    Example::

        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
        >>> m.sample()  # doctest: +SKIP  (random draw on the simplex)
        tensor([ 0.1046,  0.8954])

    Args:
        concentration (Tensor): concentration parameter of the distribution
            (often referred to as alpha). The last dimension indexes the
            simplex components; leading dimensions are batch dimensions.
        validate_args (bool, optional): forwarded to the base distribution.
    """

    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1)
    }
    support = constraints.simplex
    has_rsample = True

    def __init__(self, concentration, validate_args=None):
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        self.concentration = concentration
        # Last dim is the event (simplex components); the rest is the batch.
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution broadcast to ``batch_shape``."""
        new = self._get_checked_instance(Dirichlet, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape + self.event_shape)
        # Skip validation while re-initialising the expanded copy, then
        # inherit this instance's validation flag.
        super(Dirichlet, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=()):
        """Draw a reparameterized sample of shape ``sample_shape + batch + event``."""
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        return _Dirichlet.apply(concentration)

    def log_prob(self, value, epsilon=0.0):
        """Return the log density of ``value`` under this Dirichlet.

        Args:
            value (Tensor): point(s) on the probability simplex; the last
                dimension must match the event dimension.
            epsilon (float): added to ``value`` inside the logarithm to avoid
                ``log(0)``. A non-zero epsilon moves the point slightly off
                the simplex, so the result is then only an approximation of
                the true log density. Defaults to ``0.0``.

        Returns:
            Tensor of shape ``batch_shape`` (event dimension summed out).

        Note:
            This is a log *density*, not a log probability. It is positive
            wherever the density exceeds 1 (e.g. for large concentrations),
            so a negative log-likelihood built on it can be negative.
        """
        if self._validate_args:
            self._validate_sample(value)
        # xlogy(a, x) = a * log(x) but returns 0 when a == 0, so a component
        # with concentration 1 and value 0 contributes 0 instead of 0 * -inf = nan.
        return (
            torch.xlogy(self.concentration - 1.0, value + epsilon).sum(-1)
            + torch.lgamma(self.concentration.sum(-1))
            - torch.lgamma(self.concentration).sum(-1)
        )

Output Layer

class linear_output(nn.Module):
    """Linear projection followed by ``exp`` to produce strictly positive outputs.

    Maps features of size ``mha_output_dim`` to ``model_output_dim`` positive
    values, e.g. for use as Dirichlet concentration parameters.
    """

    def __init__(self, mha_output_dim, model_output_dim) -> None:
        # nn.Module must be initialised before sub-modules are assigned;
        # otherwise assigning self.linear raises an AttributeError.
        super().__init__()
        self.mha_output_dim = mha_output_dim
        self.model_output_dim = model_output_dim
        self.linear = nn.Linear(self.mha_output_dim, self.model_output_dim)

    def forward(self, x):
        """Return ``exp(linear(x))``; the last dimension becomes ``model_output_dim``."""
        output = self.linear(x)
        # exp guarantees positivity. NOTE(review): large pre-activations can
        # overflow to inf and very negative ones underflow to 0 (an invalid
        # concentration); consider clamping or softplus if that occurs.
        return torch.exp(output)

\* `epsilon` is added inside the logarithm for numerical stability.

I wanted to ask the community whether my interpretation is correct.