Sudden explosion of the training loss

Hi, I am using the output of my network to parameterise a Dirichlet distribution. However, the loss explodes whenever there is a point like [1,0,0] coming in as the observation, i.e. an observation with exact zeros in some components.
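To make the problem concrete, here is a minimal sketch of what is probably going on. It assumes the loss is the Dirichlet negative log-likelihood and that the predicted concentrations are all above 1. The `alpha` values and the helper names `dirichlet_nll` and `near_vertex` are made up for illustration and are not taken from the actual network.

```python
import math

def dirichlet_nll(x, alpha):
    """Negative log-likelihood of one observation x under Dirichlet(alpha)."""
    # Log normalising constant: log Gamma(sum alpha) - sum log Gamma(alpha_i)
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    log_kernel = 0.0
    for xi, a in zip(x, alpha):
        # math.log(0.0) raises ValueError, so treat log(0) as -inf explicitly
        log_xi = math.log(xi) if xi > 0.0 else -math.inf
        log_kernel += (a - 1.0) * log_xi
    return -(log_norm + log_kernel)

def near_vertex(eps):
    """Point close to [1, 0, 0] that still lies strictly inside the simplex."""
    return [1.0 - 2.0 * eps, eps, eps]

# Hypothetical concentrations, all > 1 (an assumption, not the real model output)
alpha = [2.0, 3.0, 4.0]

# The loss grows without bound as the observation approaches the vertex
for eps in (1e-2, 1e-4, 1e-8):
    print(eps, dirichlet_nll(near_vertex(eps), alpha))

# Exactly at the vertex the (alpha_i - 1) * log(0) terms make the loss infinite
print(dirichlet_nll([1.0, 0.0, 0.0], alpha))
```

With concentrations above 1, every zero component contributes a `(alpha_i - 1) * log(0)` term, so the loss is infinite at [1,0,0] and very large for observations close to it. If a concentration were below 1 the sign would flip, and at exactly 1 the product `0 * inf` would give NaN, so the behaviour depends on what the network predicts.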

Does the statistical modelling community know of any methods for resolving this problem?
