nn.functional.one_hot taking exceptionally long

I’m training a GAN on a computing cluster with PyTorch Lightning, and I’m trying to track down some bottlenecks with a profiler. As the profiler output below shows, one_hot() is taking a very long time, with each call averaging 0.03 seconds:
(screenshot of profiler output omitted)
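
(For reference, the timings come from a profiler attached to the Lightning Trainer. Enabling one of Lightning's built-in profilers looks roughly like this; the exact arguments in my setup aren't important here:)

    import pytorch_lightning as pl

    # "simple" is one of Lightning's built-in profilers; this is only a sketch of
    # the kind of setup used here, not the exact training configuration
    trainer = pl.Trainer(profiler="simple")
    trainer.fit(model)  # model is the LightningModule wrapping the GAN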

Here is the relevant code, sampling a conditional probability distribution from the generator:

    def sample_G(self, n) -> Tuple[Tensor, Tensor]:
        """
        Generates n generator samples. In this conditional GAN case we're sampling G with a uniform class label prior.

        @param n: Number of samples
        @return: Tensor of generator samples
        """
        z = self.sample_z(n)
        y = tr.floor(tr.rand(n, device=next(self.G.parameters()).device) * 10).long()  # uniform class labels in [0, 10)
        y_one_hot = F.one_hot(y, num_classes=10).float()  # one-hot encode to shape (n, 10)

        return self.G(z, y_one_hot)

Any help would be greatly appreciated.

Could you post the input shape to F.one_hot?

Here:
torch.Size([128])

It’s just the batch size

I cannot reproduce the 30ms using the PyTorch 1.8.0 binaries:

    import torch
    import torch.nn.functional as F
    import time

    y = torch.randint(0, 10, (128,), device='cuda')

    # warmup
    for _ in range(10):
        y_one_hot = F.one_hot(y, num_classes=10)

    nb_iters = 1000
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(nb_iters):
        y_one_hot = F.one_hot(y, num_classes=10)
    torch.cuda.synchronize()
    t1 = time.time()
    print('{:.3f}us/iter'.format(((t1 - t0)/nb_iters)*1e6))

Output

83.642us/iter
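
As a side note, the class labels could also be drawn directly with torch.randint instead of tr.floor(tr.rand(...) * 10), which is what the benchmark above does. Using the n and self.G from your sample_G, that would look roughly like:

    device = next(self.G.parameters()).device
    y = torch.randint(0, 10, (n,), device=device)     # uniform class indices in [0, 10)
    y_one_hot = F.one_hot(y, num_classes=10).float()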

I’ll try re-running the code with time.time()

Getting between 100 and 6000 us per call. It looks like some calls take much longer than others.
(screenshot of per-call timings omitted)

Also, I’m on PyTorch 1.7.1.

Here’s the code again:

    def sample_G(self, n) -> Tuple[Tensor, Tensor]:
        """
        Generates n generator samples. In this conditional GAN case we're sampling G with a uniform class label prior.

        @param n: Number of samples
        @return: Tensor of generator samples
        """
        z = self.sample_z(n)
        y = tr.floor(tr.rand(n, device=next(self.G.parameters()).device) * 10).long()
        t0 = time.time()
        y_one_hot = F.one_hot(y, num_classes=10).float()
        t1 = time.time()
        print(t1 - t0)  # elapsed wall-clock time in seconds

        return self.G(z, y_one_hot)

Hi Robbie!

Please note the torch.cuda.synchronize() in @ptrblck’s sample
code. Without it you will get unreliable GPU timings.
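
For example, in your sample_G the timing block could be bracketed with
synchronizations (a sketch using the same y, F, and microsecond conversion
as in the snippets above):

    torch.cuda.synchronize()    # flush any queued gpu work before starting the timer
    t0 = time.time()
    y_one_hot = F.one_hot(y, num_classes=10).float()
    torch.cuda.synchronize()    # wait for the one_hot kernel to finish
    t1 = time.time()
    print('{:.1f} us'.format((t1 - t0) * 1e6))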

(Also, please avoid posting screenshots of textual information. It
breaks accessibility, searchability, and copy-paste.)

Best.

K. Frank


Getting 100-200 us per call now, which appears to be pretty normal:

(pid=14209) 172.13821411132812
(pid=14209) 195.50323486328125
(pid=14204) 175.23765563964844
(pid=14209) 151.3957977294922
(pid=14204) 189.30435180664062
(pid=14204) 145.43533325195312
(pid=14204) 207.1857452392578
(pid=14209) 212.4309539794922
(pid=14209) 130.8917999267578
(pid=14209) 186.44332885742188
(pid=14204) 169.27719116210938
(pid=14209) 171.661376953125
(pid=14204) 152.34947204589844
(pid=14204) 128.5076141357422
(pid=14204) 190.25802612304688
(pid=14209) 181.43653869628906
(pid=14209) 203.60946655273438
(pid=14209) 154.4952392578125
(pid=14204) 209.808349609375
(pid=14204) 244.140625
(pid=14209) 182.15179443359375
(pid=14204) 129.22286987304688
(pid=14204) 150.44212341308594
(pid=14209) 229.35867309570312