Performance of sampling from random distributions

Hello,

I am running a training algorithm, and in one step I need to sample from a Gaussian distribution with a given standard deviation.

Each epoch of my training takes around 5 seconds if I skip the sampling step.
However, if I do the sampling, it becomes far too slow (1 epoch = 120 seconds)!

I am doing it using torch.distributions.
Before starting the training, I create a normal distribution object:

import torch.distributions as tdist

normal_dist = tdist.Normal(0.0, sigma)

Then, within my training loop:

noise = normal_dist.sample(w.size())

Any hints as to why this is happening, and whether there is a more efficient way?

EDIT:

  • torch.randn is just as slow.
  • I compared TensorFlow vs. PyTorch performance on random sampling: when the output noise tensor is small, PyTorch tends to be faster, but when sampling big tensors, TensorFlow is much faster and PyTorch becomes too slow.
  • The size of the output in my experiment is 1024x128x128.

Could you try to time the noise creation?
Are you using the CPU or GPU?
If you’re using the GPU, you should synchronize before starting and stopping the timer:

import time

import torch

torch.cuda.synchronize()  # wait for pending GPU work before starting the timer
t0 = time.time()
noise = normal_dist.sample(w.size())
torch.cuda.synchronize()  # wait for the sample to be ready before stopping
t1 = time.time()
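As a side note, it can also help to run one warm-up sample before the timed section, since the very first CUDA operation carries one-time initialization overhead.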

I wrote a small benchmark script to time the noise creation in PyTorch vs. TensorFlow.

Noise size = (1024, 128, 128):
PyTorch time = 0.09 s
TF time = 0.03 s

Noise size = (1024, 256, 256):
PyTorch time = 0.398 s
TF time = 0.118 s

Noise sampling with torch.randn or normal_dist.sample happens on the CPU, and the result is then moved to the GPU, so I measured the generation time only, without actually moving the tensors to the GPU.
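A minimal sketch of this kind of benchmark (assuming TensorFlow 2.x; the exact script may have differed, and timings vary by machine):

import time

import tensorflow as tf
import torch

# Warm-up so one-time library initialization is not counted.
_ = torch.randn(1)
_ = tf.random.normal((1,))

for size in [(1024, 128, 128), (1024, 256, 256)]:
    t0 = time.time()
    _ = torch.randn(size)                 # PyTorch sampling on the CPU
    torch_time = time.time() - t0

    with tf.device("/CPU:0"):             # keep TF on the CPU for a fair comparison
        t0 = time.time()
        _ = tf.random.normal(size)
        tf_time = time.time() - t0

    print(f"size={size}  PyTorch={torch_time:.3f}s  TF={tf_time:.3f}s")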


Similar issue with

import torch
from torch.distributions import MultivariateNormal

distrib = MultivariateNormal(
    loc=torch.zeros(input_dimensionality),
    covariance_matrix=torch.diag(torch.ones(input_dimensionality[0])),
)
distrib.sample((N,)).cuda()

for N = 5 and input_dimensionality = 16K.
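Since the covariance here is the identity, one possible workaround (a sketch, assuming an identity covariance is really what is wanted) is to skip MultivariateNormal entirely: sampling from N(0, I) reduces to i.i.d. standard normals per dimension, which avoids the Cholesky factorization of the 16K x 16K covariance at construction and the dense matrix multiply during sampling.

import torch
from torch.distributions import Independent, Normal

N, d = 5, 16_000  # values from the post above

# For zero mean and identity covariance, the d components are i.i.d.
# N(0, 1), so a plain randn draws from the same distribution.
samples = torch.randn(N, d).cuda()

# If a Distribution object is still needed, Independent + Normal keeps
# the (N, d) event structure without building a d x d covariance matrix.
distrib = Independent(Normal(torch.zeros(d), torch.ones(d)), 1)
samples = distrib.sample((N,)).cuda()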

Are there any updates on this?

Is there any update on this? I am also finding PyTorch sampling slow. I am using torch.distributions.Normal(0, 1).sample.
Edit: I found torch.empty().normal_() faster than torch.distributions.Normal(0, 1).sample, so I am just managing with this for now.
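For reference, a quick way to check that difference yourself (a sketch, using the noise size reported earlier in the thread; numbers will vary by machine, and this is CPU sampling):

import time

import torch

size = (1024, 128, 128)

t0 = time.time()
_ = torch.distributions.Normal(0.0, 1.0).sample(size)
print(f"Normal(0, 1).sample:     {time.time() - t0:.3f} s")

t0 = time.time()
_ = torch.empty(size).normal_()  # fill an uninitialized tensor in place
print(f"torch.empty().normal_(): {time.time() - t0:.3f} s")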