I’ve noticed that the MultivariateNormal constructor can take seconds to execute for large batch sizes when its arguments are GPU tensors. For the code snippet below, the constructor alone takes 2.5 seconds on a Titan Xp but only 15 milliseconds on the CPU. If a policy needs to create a new distribution on every forward pass, this becomes a significant overhead. What processing happens in the constructor that could explain this execution time?
import time

import torch

cuda = True  # change this to False to see CPU time
if cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

batch_size = 1000
event_size = 10

# Batch of mean vectors, shape (batch_size, event_size)
mean = torch.randn(batch_size, event_size, dtype=torch.float32)
mean = mean.to(device)

# Batch of identity covariance matrices, shape (batch_size, event_size, event_size)
covariance = torch.eye(event_size, dtype=torch.float32)
covariance = covariance.unsqueeze(0).expand(batch_size, -1, -1)
covariance = covariance.to(device)

# Wait for pending GPU work so it is not counted in the measurement
if cuda:
    torch.cuda.synchronize()
t0 = time.time()
torch.distributions.MultivariateNormal(mean, covariance)
# Wait for the constructor's GPU kernels to finish before stopping the clock
if cuda:
    torch.cuda.synchronize()
t1 = time.time()
dt = t1 - t0
print("MultivariateNormal constructor took {} seconds".format(dt))
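One way to narrow down where the time goes is sketched below, assuming a recent PyTorch release that provides `torch.linalg.cholesky`. It runs a warm-up call first, so one-off CUDA context or solver-library initialization is not counted, then averages several constructor calls. It compares passing `covariance_matrix` (from which the constructor computes a Cholesky factor) against passing a precomputed `scale_tril` (which skips that factorization). `validate_args=False` turns off argument validation, which includes a positive-definiteness check on the covariance. The helper `time_constructor` is hypothetical, and the script falls back to the CPU when no GPU is available.

```python
import time

import torch
from torch.distributions import MultivariateNormal


def time_constructor(make_dist, device, repeats=10):
    """Hypothetical helper: average wall-clock seconds per make_dist() call."""
    make_dist()  # warm-up, so one-off initialization cost is excluded
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(repeats):
        make_dist()
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
    return (time.perf_counter() - t0) / repeats


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size, event_size = 1000, 10
mean = torch.randn(batch_size, event_size, device=device)
covariance = torch.eye(event_size, device=device).expand(batch_size, -1, -1)

# Path 1: the constructor factorizes covariance_matrix itself on every call.
t_cov = time_constructor(
    lambda: MultivariateNormal(mean, covariance_matrix=covariance, validate_args=False),
    device,
)

# Path 2: factorize once up front and pass scale_tril, so the constructor skips it.
scale_tril = torch.linalg.cholesky(covariance)
t_tril = time_constructor(
    lambda: MultivariateNormal(mean, scale_tril=scale_tril, validate_args=False),
    device,
)

print(f"covariance_matrix: {t_cov * 1e3:.2f} ms, scale_tril: {t_tril * 1e3:.2f} ms")
```

If the `scale_tril` path is much faster, the factorization is the dominant cost; if the first (warm-up) call is the only slow one, the 2.5 seconds is mostly initialization rather than per-call work.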