MultivariateNormal constructor with GPU tensors takes seconds to execute for large batch sizes

I’ve noticed that the MultivariateNormal constructor can take seconds to execute for large batch sizes when its arguments are GPU tensors. For the snippet below, the constructor alone takes 2.5 seconds on a Titan Xp but only 15 milliseconds on the CPU. If a policy has to create a distribution on every forward pass, this becomes a significant overhead. What processing in the constructor could account for the execution time?

import torch
import time

cuda = True # change this to False to see CPU time

if cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

batch_size = 1000
event_size = 10

mean = torch.randn(batch_size, event_size, dtype=torch.float32)
mean = mean.to(device)

covariance = torch.eye(event_size, dtype=torch.float32)
covariance = covariance.unsqueeze(0).expand(batch_size, -1, -1)
covariance = covariance.to(device)

# CUDA kernels launch asynchronously, so synchronize before and after timing.
if cuda:
    torch.cuda.synchronize()
t0 = time.perf_counter()
torch.distributions.MultivariateNormal(mean, covariance)
if cuda:
    torch.cuda.synchronize()
t1 = time.perf_counter()
dt = t1 - t0

print("MultivariateNormal constructor took {} seconds".format(dt))

For anybody else with a similar issue: creating the distribution by passing scale_tril explicitly is much more efficient. The GPU version is still a bit slower than the CPU version, but not by much. According to the documentation, when a covariance matrix is passed, the corresponding lower-triangular factor is computed with a Cholesky decomposition. My guess is that this internal Cholesky decomposition is much slower on GPU tensors than on CPU tensors for some reason.
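In case it helps, here is a rough sketch of how the scale_tril workaround and the Cholesky guess could be checked side by side. It is not taken from the original snippet: the `timed` helper is a made-up name, `torch.linalg.cholesky` assumes a reasonably recent PyTorch (1.8 or later), and the identity covariance is only a stand-in for a real one.

```python
import time

import torch
from torch.distributions import MultivariateNormal

# Falls back to the CPU when no GPU is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 1000
event_size = 10

mean = torch.randn(batch_size, event_size, device=device)
covariance = torch.eye(event_size, device=device).expand(batch_size, -1, -1)


def timed(fn):
    """Hypothetical helper: run fn once, return (result, elapsed seconds)."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    result = fn()
    if device.type == "cuda":
        torch.cuda.synchronize()
    return result, time.perf_counter() - t0


# The factorization on its own, to see how much of the constructor time it explains.
scale_tril, t_chol = timed(lambda: torch.linalg.cholesky(covariance))

# Constructor given a covariance matrix: it has to factorize internally.
dist_cov, t_cov = timed(
    lambda: MultivariateNormal(mean, covariance_matrix=covariance)
)

# Constructor given the precomputed factor: no factorization needed.
dist_tril, t_tril = timed(
    lambda: MultivariateNormal(mean, scale_tril=scale_tril)
)

print("cholesky alone:            {:.6f} s".format(t_chol))
print("constructor (covariance):  {:.6f} s".format(t_cov))
print("constructor (scale_tril):  {:.6f} s".format(t_tril))
```

The first CUDA call in a process can include one-time initialization, so it is probably worth running each measurement twice and comparing the second numbers. In a policy network, the simplest way to benefit is to have the model output the lower-triangular factor directly and pass it as scale_tril, so that no covariance matrix ever needs to be factorized.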