Laplace.sample() returns inf in half precision


I’m training a network in fp16 precision, and it has to sample from a Laplace distribution.
I find that in fp16, Laplace.sample() can return torch.inf, while in fp32 this never happens.
Code to reproduce:

import torch
import torch.distributions as dist
def test(r):
    # count how many of 1000 batches contain at least one inf sample
    count = 0
    for i in range(1000):
        if r.sample().isinf().any():
            count += 1
    return count

ra = dist.laplace.Laplace(torch.randn(100).cuda().half(),
                          torch.ones(100).cuda().half())  # scale; any positive fp16 scale reproduces it
print(test(ra))  # prints a non-zero count
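I suspect the root cause is the inverse-CDF transform that Laplace.rsample uses internally, which computes log1p(-|u|) for a uniform draw u: in fp16 a draw that is merely close to 1 can round to exactly 1.0, and log1p(-1) is -inf. A small CPU-only demonstration of that rounding (the 0.99999 constant is my own illustration, not from the sampler):

```python
import torch

# fp16 has only ~3 decimal digits of precision near 1.0, so a value that is
# merely close to 1 rounds to exactly 1.0
u = torch.tensor(0.99999, dtype=torch.float16)
print(u.item())  # 1.0 -- rounded up in fp16

# the Laplace inverse-CDF transform evaluates log1p(-|u|);
# at u == 1 this is log(0) = -inf
print(torch.log1p(-u.float()).item())  # -inf
```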

Is it possible to solve this problem?
I know I can manually cast to fp32 for this operation, but I wonder if there is a cleaner way.
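For reference, this is what I mean by manually casting (a CPU sketch; the loc and scale tensors here are placeholders for my real fp16 parameters):

```python
import torch
import torch.distributions as dist

loc = torch.randn(100)          # would be .cuda().half() in my setup
scale = torch.rand(100) + 0.5   # placeholder positive scale

# build the distribution in fp32, sample, then cast back down to fp16
ra32 = dist.laplace.Laplace(loc.float(), scale.float())
s = ra32.sample().half()
print(s.isinf().any())  # tensor(False)
```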

We generally do not recommend calling .half() manually; use our mixed-precision utilities (torch.autocast) instead, which should cast ops to the needed wider dtype automatically.
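A minimal sketch of that pattern (the model and tensor names here are my own illustration, not from the original post): weights stay fp32 and autocast picks per-op dtypes, so you keep full precision for the distribution's parameters rather than force-casting everything with .half(). This uses CPU autocast with bfloat16 so it runs anywhere; on GPU you would use device_type="cuda" with dtype=torch.float16, typically together with a gradient scaler.

```python
import torch

model = torch.nn.Linear(100, 100)      # parameters stay fp32
x = torch.randn(8, 100)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)                     # matmul runs in reduced precision
print(out.dtype)                       # torch.bfloat16

# keep the distribution's parameters in full precision for sampling
loc = out.float()
laplace = torch.distributions.Laplace(loc, torch.ones_like(loc))
print(laplace.sample().isinf().any())  # tensor(False)
```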