Error while sampling from Dirichlet distribution

Traceback (most recent call last):
  File "", line 167, in <module>
  File "", line 142, in test
    _, predicted = predict(outputs).max(1)
  File "", line 43, in predict
    preds = dist.sample()
  File "/u/sahariac/.local/lib/python3.6/site-packages/torch/distributions/", line 97, in sample
    return self.rsample(sample_shape)
  File "/u/sahariac/.local/lib/python3.6/site-packages/torch/distributions/", line 68, in rsample
    return _Dirichlet.apply(concentration)
  File "/u/sahariac/.local/lib/python3.6/site-packages/torch/distributions/", line 28, in forward
    x = _dirichlet_sample_nograd(concentration)
  File "/u/sahariac/.local/lib/python3.6/site-packages/torch/distributions/", line 12, in _dirichlet_sample_nograd
    probs = torch._standard_gamma(concentration)
RuntimeError: _standard_gamma is not implemented for type torch.cuda.FloatTensor

Which version of PyTorch are you running on?
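A quick way to check is python -c "import torch; print(torch.__version__)". If a script should decide on its own whether to expect CUDA Dirichlet sampling to work, here is a minimal sketch. The helper names are hypothetical. The 0.4.1 threshold reflects only what this thread reports (0.4.1 samples on CUDA without error, and the older install did not). It is not an official compatibility table.

```python
import re


def parse_torch_version(version: str) -> tuple:
    """Return (major, minor, patch) from strings like '0.4.1',
    '0.4.1.post2' or '1.13.1+cu117'; trailing tags are ignored."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    # A missing patch component (e.g. '2.1') is treated as 0.
    return int(major), int(minor), int(patch or 0)


def cuda_dirichlet_expected_to_work(version: str) -> bool:
    # Hypothetical threshold taken from this thread: 0.4.1 worked,
    # while an older install raised the _standard_gamma RuntimeError.
    return parse_torch_version(version) >= (0, 4, 1)
```

For example, cuda_dirichlet_expected_to_work(torch.__version__) returns False on 0.4.0 and True on 0.4.1.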

With PyTorch 0.4.1, the following runs without error:

>>> import torch
>>> import torch.distributions.dirichlet as d
>>> m = d.Dirichlet(torch.tensor([0.5, 0.5], device='cuda'))
>>> m.sample()
tensor([0.5290, 0.4710], device='cuda:0')

Could you provide some code that causes this error?

If you are on 0.4.0, try upgrading to 0.4.1 and see whether that fixes the error.
If you installed PyTorch with conda, you can upgrade by running:

    conda update pytorch -c pytorch
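If upgrading is not an option, one possible workaround is to draw the sample on the CPU and move the result back to the GPU. The error message only says _standard_gamma is missing for torch.cuda.FloatTensor. This workaround is an assumption and was not tested in this thread. The helper name sample_dirichlet is hypothetical. It would stand in for the dist.sample() call inside predict.

```python
import torch
from torch.distributions import Dirichlet


def sample_dirichlet(concentration: torch.Tensor) -> torch.Tensor:
    """Hypothetical workaround: draw one Dirichlet sample per row of
    `concentration` on the CPU, then return it on the original device."""
    device = concentration.device
    # Build the distribution from a detached CPU copy so the gamma draws
    # run on the CPU instead of the CUDA path that raised the error.
    dist = Dirichlet(concentration.detach().cpu())
    # sample() (not rsample()) matches the original dist.sample() call,
    # so no gradient flows through the draw.
    return dist.sample().to(device)
```

Usage would look like _, predicted = sample_dirichlet(outputs).max(1). The trade-off is an extra host/device copy on every call.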

Thanks. Updating to 0.4.1 worked for me.