Expected kernel launch delay

Hi,
I am wondering what the expected time for kernel calls is. PyTorch kernel calls are asynchronous, so the GPU can do work while the CPU already launches new kernels. I would expect something like the forward pass to be very quick in Python as long as there are no synchronization points. Obviously, when I need the result, I have to wait for the GPU.
However, when I benchmarked this by simply measuring time in Python for different operations, it seemed like these were blocking. E.g.
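
To illustrate what I mean by a synchronization point (just a minimal sketch, not my actual benchmark):

import torch

x = torch.rand(4096, 4096, device="cuda")

# This returns almost immediately; the matmul kernel is only queued.
y = x @ x

# This blocks until the GPU has actually finished the matmul.
torch.cuda.synchronize()

# Reading the result would also force a sync, e.g.:
value = y.sum().item()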

- prediction.log_prob(target).mean()

takes longer than the forward pass of my model for some reason.

Any idea whether this is to be expected?

How did you measure these operations? E.g. are you seeing a long kernel launch for this stand-alone operation, or as part of a larger workload, etc.?

Sorry for not being clear. Here are some examples that I ran on a GTX 1080 Ti with PyTorch 1.8.
Maybe I am misunderstanding something about asynchronous execution.

from timeit import default_timer as timer

import torch
from torch import nn, jit
from torch.distributions import Normal
from tqdm import tqdm


class Timer:
    def __init__(self, name=None):
        self.name = name

        self.start = None
        self.end = None

    def __enter__(self):
        self.start = timer()
        return self

    def __exit__(self, *args, **kwargs):
        end = timer()
        duration = end - self.start

        print("DURATION:", self.name, duration)

Example 1: Simple multiplication

with Timer("kernel complete"):
    input_cpu = torch.rand(32, 512, 512).pin_memory()
    parameter_cpu = torch.rand(32, 512, 512, requires_grad=True).pin_memory()

    with Timer("kernel launch"):
        for i in tqdm(range(512)):
            input_gpu = input_cpu.to("cuda", non_blocking=True)
            parameter_gpu = parameter_cpu.to("cuda", non_blocking=True)

            result = input_gpu * parameter_gpu

    torch.cuda.synchronize()
DURATION: kernel launch 0.040197614999669895
DURATION: kernel complete 14.410601890999715

Example 2: Simple multiplication with backward pass

with Timer("kernel complete"):
    input_cpu = torch.rand(32, 512, 512).pin_memory()
    parameter_cpu = torch.rand(32, 512, 512, requires_grad=True).pin_memory()

    with Timer("kernel launch + cpu->gpu transfer"):
        with Timer("kernel launch"):
            for i in tqdm(range(512)):
                input_gpu = input_cpu.to("cuda", non_blocking=True)
                parameter_gpu = parameter_cpu.to("cuda", non_blocking=True)

                result = input_gpu * parameter_gpu

            result.mean().backward()

        input_gpu = input_cpu.to("cuda", non_blocking=True)
        parameter_gpu = parameter_cpu.to("cuda", non_blocking=True)

    torch.cuda.synchronize()
DURATION: kernel launch 5.375270492999334
DURATION: kernel launch + cpu->gpu transfer 5.3755872629990336
DURATION: kernel complete 6.740538608999486

Example 3: torch.distributions.Normal.sample() seems to block:

with Timer("kernel complete"):
    input_cpu = torch.rand(32, 512, 512).pin_memory()
    parameter_cpu = torch.rand(32, 512, 512, requires_grad=True).pin_memory()

    with Timer("kernel launch"):
        for i in tqdm(range(512)):
            input_gpu = input_cpu.to("cuda", non_blocking=True)
            parameter_gpu = parameter_cpu.to("cuda", non_blocking=True)

            result = Normal(input_gpu, 1).sample() * parameter_gpu

    torch.cuda.synchronize()
DURATION: kernel launch 5.955725089000225
DURATION: kernel complete 7.688505447999887

Example 4: torch.normal() launches quicker?

with Timer("kernel complete"):
    input_cpu = torch.rand(32, 512, 512).pin_memory()
    parameter_cpu = torch.rand(32, 512, 512, requires_grad=True).pin_memory()

    with Timer("kernel launch"):
        for i in tqdm(range(512)):
            input_gpu = input_cpu.to("cuda", non_blocking=True)
            parameter_gpu = parameter_cpu.to("cuda", non_blocking=True)

            result = (torch.normal(input_gpu, std=1)) * parameter_gpu

    torch.cuda.synchronize()
DURATION: kernel launch 1.862492633000329
DURATION: kernel complete 7.237628412999584

Example 5: log_prob also blocks?


with Timer("kernel complete"):
    input_cpu = torch.rand(32, 512, 512).pin_memory()
    parameter_cpu = torch.rand(32, 512, 512, requires_grad=True).pin_memory()

    with Timer("kernel launch"):
        for i in tqdm(range(512)):
            input_gpu = input_cpu.to("cuda", non_blocking=True)
            parameter_gpu = parameter_cpu.to("cuda", non_blocking=True)

            result = Normal(input_gpu, 1).log_prob(parameter_gpu)

    torch.cuda.synchronize()
DURATION: kernel launch 6.612539947000187
DURATION: kernel complete 8.380056750000222

Here, the distributions are blocking because of the scalar parameters, which are wrapped in tensors (and copied to the GPU, blocking Python). This is avoidable with pre-created tensors (even better is to use the simplified formula for scale=1).
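
A minimal sketch of what I mean (the last line is just the Normal log density written out for scale=1):

import math
import torch
from torch.distributions import Normal

mu = torch.rand(32, 512, 512, device="cuda")
x = torch.rand(32, 512, 512, device="cuda")

# Blocking: the Python scalar 1 gets wrapped in a tensor inside Normal.
lp_scalar = Normal(mu, 1).log_prob(x)

# Avoidable: pre-create the scale tensor on the GPU once, outside the loop.
scale = torch.ones((), device="cuda")
lp_tensor = Normal(mu, scale).log_prob(x)

# Even better for scale=1: use the simplified formula directly.
lp_manual = -0.5 * (x - mu) ** 2 - 0.5 * math.log(2 * math.pi)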

I’m not sure about the backward() snippet; I think writing parameter_cpu.grad is blocking. But a sync before starting backprop may also make sense…
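
If the .grad write is the issue, keeping the parameter on the GPU should avoid it (rough sketch, not verified):

import torch

input_cpu = torch.rand(32, 512, 512).pin_memory()
# Leaf tensor created directly on the GPU, so the gradient is accumulated
# into parameter_gpu.grad on the GPU and nothing is copied back to CPU memory.
parameter_gpu = torch.rand(32, 512, 512, device="cuda", requires_grad=True)

input_gpu = input_cpu.to("cuda", non_blocking=True)
result = input_gpu * parameter_gpu
result.mean().backward()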

Example 2 was chosen poorly. I updated it to better reflect my problem.

Update for example 5 with @googlebot’s explanation. Still far from what I would expect…

with Timer("kernel complete"):
    input_cpu = torch.rand(32, 512, 512).pin_memory()
    parameter_cpu = torch.rand(32, 512, 512, requires_grad=True).pin_memory()
    stddev = torch.ones_like(input_cpu, device="cuda")

    with Timer("kernel launch"):
        for i in tqdm(range(N_ITERATIONS)):
            input_gpu = input_cpu.to("cuda", non_blocking=True)
            parameter_gpu = parameter_cpu.to("cuda", non_blocking=True)

            result = Normal(input_gpu, stddev).log_prob(parameter_gpu)

    torch.cuda.synchronize()

With N_ITERATIONS = 512

DURATION: kernel launch 4.855967344999954
DURATION: kernel complete 7.614574080998864

With N_ITERATIONS = 2048

DURATION: kernel launch 23.592640275999656
DURATION: kernel complete 26.401106097000593

I suspect you’re using 1.8, and this change affects blocking:

Enable distribution validation by default for torch.distributions (#48743)

This may slightly slow down some models. Concerned users may disable validation by using torch.distributions.Distribution.set_default_validate_args(False) or by disabling individual distribution validation via MyDistribution(…, validate_args=False).

FYI, the previous version doesn’t block in this snippet unless validate_args=["scale"] is added.
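
I.e., something along these lines (both options are taken from the release note quoted above):

import torch
from torch.distributions import Normal

# Disable argument validation globally...
torch.distributions.Distribution.set_default_validate_args(False)

# ...or only for a single distribution instance:
mu = torch.rand(32, 512, 512, device="cuda")
lp = Normal(mu, 1.0, validate_args=False).log_prob(torch.rand_like(mu))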


I just tested it with example 5, and it does not seem to make a difference.

I have to correct myself. These are some results for example 5:

With set_default_validate_args(True) and N_ITERATIONS=32

DURATION: kernel launch 0.4117265930000258
DURATION: kernel complete 1.801716031000069

With set_default_validate_args(False) and N_ITERATIONS=32

DURATION: kernel launch 0.019914979000077437
DURATION: kernel complete 1.801976835000005

With set_default_validate_args(True) and N_ITERATIONS=512

DURATION: kernel launch 6.590126628999997
DURATION: kernel complete 8.072248363000085

With set_default_validate_args(False) and N_ITERATIONS=512

DURATION: kernel launch 4.870119396000064
DURATION: kernel complete 7.648874833000036

Is there something else I am missing? There seems to be a big difference in launch time depending on N_ITERATIONS.

I guess you’re overloading the CUDA launch queue. Or garbage collection triggers a CUDA sync somehow.
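
One way to check both guesses (rough sketch: gc.disable() rules out collection pauses, and if the measured launch time per iteration only blows up for large iteration counts, that points to the launch queue filling up):

import gc
from timeit import default_timer as timer

import torch

mu = torch.rand(32, 512, 512, device="cuda")
x = torch.rand(32, 512, 512, device="cuda")

gc.disable()  # rule out garbage collection pauses during the launch loop
start = timer()
for _ in range(2048):
    result = mu * x  # only queued; starts blocking once the launch queue is full
launch = timer() - start
torch.cuda.synchronize()
total = timer() - start
gc.enable()

print("launch:", launch, "complete:", total)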
