@KFrank Thank you very much for the explanation. That makes a lot of sense. I made a small change in the post: I was actually getting no improvement in loss with method 1 (which makes sense now, since the gradient is 0). I set `torch.autograd.set_detect_anomaly(True)` in my training routine and used method 2. The loss actually becomes `nan` rather than blowing up (as mentioned earlier in the post). I get the following error:

`RuntimeError: Function 'SqrtBackward0' returned nan values in its 0th output.`

I think the error comes from the sample generation in method 2, which takes the square root of the variances. The variable `variances` is produced by a `Linear` layer followed by a `ReLU` activation, so it is never negative. I am therefore not sure why the loss is becoming `nan`.
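As a sanity check, I can reproduce a `nan` from `SqrtBackward0` in isolation: `ReLU` can output exactly 0, `sqrt` has an infinite derivative at 0, and a zero upstream gradient then produces 0/0 in the backward pass (toy values below, not my actual network):

```python
import torch

# a variance that is exactly 0, as ReLU can produce
v = torch.zeros(1, requires_grad=True)
s = torch.sqrt(v)  # forward is fine: sqrt(0) = 0

# sqrt's backward computes grad / (2 * sqrt(v)), i.e. grad / 0 here;
# a zero upstream gradient then gives 0 / 0 = nan
(s * 0.0).sum().backward()
print(v.grad)  # tensor([nan])
```

With `set_detect_anomaly(True)` enabled, this raises the same `SqrtBackward0` error.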

To avoid this situation, I removed the `torch.sqrt()` and just used:

`samples[t, ...] = means + torch.mul(variances, torch.randn_like(means))`

The network now learns to predict the standard deviations (although the variable is still named `variances`). When I do this, I get the following error:

`RuntimeError: Function 'LogBackward0' returned nan values in its 0th output.`

I am using a stochastic classification loss (eq 11 of this paper: link) that performs `torch.log` operations. I have the following implementation of the loss function:

```
class HeteroscedasticLoss(torch.nn.Module):
    def __init__(self, T=100, batch_size=128):
        super().__init__()
        self.T = T
        self.batch_size = batch_size

    def forward(self, outputs, targets):
        # convert to [batch_size, T, ..., num_classes]
        outputs = torch.swapaxes(outputs, 1, 0)
        probs = torch.zeros_like(outputs, device=targets.device)
        for i, x in enumerate(outputs):
            probs[i, ...] = torch.softmax(x, -1)
        probs = torch.mean(probs, 1)  # average the T softmax samples
        loss = torch.nn.functional.nll_loss(torch.log(probs), targets)
        return loss
```
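I can also trigger the `LogBackward0` error in isolation: `softmax` can underflow to exactly 0 in float32, and `log`'s backward divides the upstream gradient by its input, so a zero probability at a non-target class (where the upstream gradient is 0) gives 0/0 (made-up logits below, not my real outputs):

```python
import torch
import torch.nn.functional as F

# softmax underflows to exactly 0.0 in float32 for large logit gaps
logits = torch.tensor([[0.0, 200.0]])
probs = torch.softmax(logits, -1)  # tensor([[0., 1.]])

# log's backward divides the upstream gradient by its input;
# the non-target class gets upstream gradient 0, so 0 / 0 = nan
p = probs.detach().requires_grad_(True)
loss = F.nll_loss(torch.log(p), torch.tensor([1]))
loss.backward()
print(p.grad)  # tensor([[nan, -1.]])
```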

I also used a different implementation (eq 12 of the paper):

```
class HeteroscedasticLoss(torch.nn.Module):
    def __init__(self, T=100, batch_size=128):
        super().__init__()
        self.T = T
        self.batch_size = batch_size

    def forward(self, outputs, targets):
        # convert to [batch_size, T, ..., num_classes]
        outputs = torch.swapaxes(outputs, 1, 0)
        loss = 0
        for i, x in enumerate(outputs):
            lsm = torch.logsumexp(input=x, dim=-1)  # [T,]
            # average softmax weight of the target class over the T samples
            arg = torch.sum(torch.exp(x[..., targets[i].item()] - lsm)) / self.T  # scalar
            loss = loss + torch.log(arg)
        loss = loss / self.batch_size
        return loss
```

In this case, the loss goes to `-inf`. I have changed the optimizer and the learning rate, but I get the same error, only a little later with smaller learning rates. I am not sure why this is happening or how to address it. Is there any way to handle this issue?
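For reference, the `-inf` can also be reproduced in isolation: if the target-class logits are low enough, `exp(x - lsm)` underflows to exactly 0 in float32 and `torch.log(arg)` returns `-inf` (toy logits below, standing in for my network outputs):

```python
import torch

T, num_classes = 4, 3
x = torch.randn(T, num_classes)
x[:, 0] = -200.0  # push the target-class logits very low

lsm = torch.logsumexp(x, dim=-1)               # [T,]
arg = torch.sum(torch.exp(x[:, 0] - lsm)) / T  # exp underflows to exactly 0
print(torch.log(arg))  # tensor(-inf)
```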