Why is torch.nn.functional.conv2d so much faster when groups > 1?

Hi

It seems that torch.nn.functional.conv2d is faster when groups > 1, even when the groups are just duplicates of each other.

For the following code, n_frames = 2 produces two identical copies of input_field and kernel and passes groups=2 to conv2d. The result should be the same as with n_frames = 1, just with two duplicate channels.

As I understand it, n_frames = 2 should increase the amount of computation. However, changing n_frames in the following code from 1 to 2 actually shortens the run time (by about 10x), and the results are also slightly different. Does this mean that groups > 1 compromises precision?

import torch
from time import time
device = 'cuda:1'
torch.manual_seed(0)

input_field = torch.rand(1, 1, 1080, 1920, dtype=torch.float32, device=device)
kernel = torch.rand(1, 1, 100, 100, dtype=torch.float32, device=device)


n_frames = 1
input_field = input_field.repeat(1, n_frames, 1, 1)
kernel = kernel.repeat(n_frames, 1, 1, 1) 

torch.cuda.synchronize(device=device)  # finish any pending GPU work before starting the timer
start_time = time()
ret = torch.nn.functional.conv2d(
        input_field,
        kernel,
        padding='same',
        groups=n_frames,
    )
torch.cuda.synchronize(device=device)
end_time = time()

print(f"Time taken for conv2d with groups={n_frames}: {end_time - start_time:.4f} seconds")
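One way to probe the precision question is to compare both configurations against a float64 reference computed on the same data. The sketch below is an illustration, not code from the thread: `grouped_vs_single` is a hypothetical helper, and it uses small tensor sizes so it also runs on CPU. On a GPU, cuDNN may choose different convolution algorithms for the grouped and ungrouped cases, and those algorithms can round differently; in addition, PyTorch allows TF32 in cuDNN convolutions by default on Ampere-or-newer GPUs (`torch.backends.cudnn.allow_tf32`), which may also contribute to small float32 discrepancies.

```python
import torch
import torch.nn.functional as F


def grouped_vs_single(height=64, width=96, ksize=7, device="cpu"):
    """Compare groups=2 (duplicated data) and groups=1 against a float64 reference.

    Returns the max absolute error of each float32 result versus the float64 one.
    """
    torch.manual_seed(0)
    x = torch.rand(1, 1, height, width, device=device)
    w = torch.rand(1, 1, ksize, ksize, device=device)

    # Single group (the n_frames = 1 case).
    single = F.conv2d(x, w, padding="same")

    # Two identical groups (the n_frames = 2 case): each output channel
    # only sees its own input channel, so both channels should match `single`.
    grouped = F.conv2d(
        x.repeat(1, 2, 1, 1), w.repeat(2, 1, 1, 1), padding="same", groups=2
    )

    # Higher-precision reference computed from the same values.
    ref = F.conv2d(x.double(), w.double(), padding="same")

    err_single = (single.double() - ref).abs().max().item()
    err_grouped = (grouped.double() - ref.repeat(1, 2, 1, 1)).abs().max().item()
    return err_single, err_grouped


if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    print(grouped_vs_single(device=dev))
```

If both errors are of similar, small magnitude, the difference between the two runs is ordinary float32 rounding from different algorithms rather than a loss of precision specific to grouped convolution; setting `torch.backends.cudnn.allow_tf32 = False` before the comparison is one way to check whether TF32 is involved.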

Thank you!

I can see a few potential improvements to make the comparison fair:

  • use torch.backends.cudnn.benchmark = True
  • warm up the GPU before each configuration run
  • run each configuration for N iterations (100 to 1000 will do)
  • use the same input values (via .clone()) to ensure the same results (if running all tests in the same script)
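The checklist above can be sketched as a small timing helper. This is a minimal sketch, not an official PyTorch utility: `benchmark_conv2d` is a hypothetical name, it falls back gracefully to CPU tensors (where the cuDNN flag has no effect), and it uses `time.perf_counter` with explicit synchronization because CUDA kernels launch asynchronously. It reports the first call separately, since that call is where warm-up and cuDNN algorithm selection happen.

```python
import time

import torch
import torch.nn.functional as F


def benchmark_conv2d(input_field, kernel, groups=1, n_warmup=10, n_iters=100):
    """Return (first_call_seconds, mean_steady_state_seconds) for F.conv2d."""
    # Let cuDNN try several algorithms for this shape and cache the fastest.
    # Note: this changes a global setting for the rest of the process.
    torch.backends.cudnn.benchmark = True
    is_cuda = input_field.is_cuda

    def sync():
        # Wait for queued GPU work so the timer measures the actual kernels.
        if is_cuda:
            torch.cuda.synchronize(input_field.device)

    # Work on copies so every configuration sees identical input values.
    x = input_field.clone()
    w = kernel.clone()

    # First call: includes warm-up and (with benchmark=True) algorithm search.
    sync()
    t0 = time.perf_counter()
    F.conv2d(x, w, padding="same", groups=groups)
    sync()
    first_call = time.perf_counter() - t0

    # Untimed warm-up iterations.
    for _ in range(n_warmup):
        F.conv2d(x, w, padding="same", groups=groups)
    sync()

    # Steady-state timing averaged over many iterations.
    t0 = time.perf_counter()
    for _ in range(n_iters):
        F.conv2d(x, w, padding="same", groups=groups)
    sync()
    mean_time = (time.perf_counter() - t0) / n_iters
    return first_call, mean_time
```

For example, calling `benchmark_conv2d(x.repeat(1, n, 1, 1), w.repeat(n, 1, 1, 1), groups=n)` for n = 1 and n = 2 in the same script gives comparable numbers for both configurations, with the one-off algorithm search kept out of the average.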

I think the results you got most likely reflect GPU warm-up (if you ran both tests in the same script) and PyTorch, through cuDNN, benchmarking different convolution algorithms and picking the best one, which together gave you the 10x time difference.

Furthermore, I would suggest using torch.profiler.profile for more precise measurements.
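A minimal sketch of profiling the convolution with `torch.profiler` is shown below. `profile_conv2d` and the `"conv2d_step"` label are hypothetical names for illustration. It does one untimed warm-up call first so that one-off setup does not dominate the trace, and it records GPU activity only when the input is on CUDA. The table is sorted by `cpu_time_total`; to sort by GPU time instead, note that the exact key name has changed between PyTorch releases, so check the documentation for your version.

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile, record_function


def profile_conv2d(input_field, kernel, groups=1, n_iters=5):
    """Profile F.conv2d and return the profiler's summary table as a string."""
    activities = [ProfilerActivity.CPU]
    if input_field.is_cuda:
        # Also record the GPU kernels that cuDNN actually launches.
        activities.append(ProfilerActivity.CUDA)

    # Warm-up outside the profiler so one-off setup is not part of the trace.
    F.conv2d(input_field, kernel, padding="same", groups=groups)

    with profile(activities=activities) as prof:
        for _ in range(n_iters):
            # Label each iteration so it is easy to find in the table.
            with record_function("conv2d_step"):
                F.conv2d(input_field, kernel, padding="same", groups=groups)
        if input_field.is_cuda:
            torch.cuda.synchronize(input_field.device)

    return prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
```

Printing the returned table shows per-operator timings, including which low-level kernels ran, which can help explain why the grouped and ungrouped calls behave so differently.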

Thanks so much for your reply!

From my testing, setting torch.backends.cudnn.benchmark = True is the key to accelerating conv2d with one group and a large kernel. Only the first conv2d call takes quite long; the following runs are all very fast:

Time taken for conv2d with groups=1, round=0: 13.8164 seconds
Time taken for conv2d with groups=1, round=1: 0.0100 seconds
Time taken for conv2d with groups=1, round=2: 0.0090 seconds
Time taken for conv2d with groups=1, round=3: 0.0090 seconds
Time taken for conv2d with groups=1, round=4: 0.0090 seconds
Time taken for conv2d with groups=1, round=5: 0.0090 seconds
Time taken for conv2d with groups=1, round=6: 0.0090 seconds
Time taken for conv2d with groups=1, round=7: 0.0090 seconds
Time taken for conv2d with groups=1, round=8: 0.0090 seconds
Time taken for conv2d with groups=1, round=9: 0.0090 seconds

And for conv2d with two groups, it seems that cuDNN is already fast enough from the very first run:

Time taken for conv2d with groups=2, round=0: 0.0200 seconds
Time taken for conv2d with groups=2, round=1: 0.0190 seconds
Time taken for conv2d with groups=2, round=2: 0.0180 seconds
Time taken for conv2d with groups=2, round=3: 0.0190 seconds
Time taken for conv2d with groups=2, round=4: 0.0180 seconds
Time taken for conv2d with groups=2, round=5: 0.0190 seconds
Time taken for conv2d with groups=2, round=6: 0.0180 seconds
Time taken for conv2d with groups=2, round=7: 0.0180 seconds
Time taken for conv2d with groups=2, round=8: 0.0190 seconds
Time taken for conv2d with groups=2, round=9: 0.0180 seconds

May I ask whether it is possible to make PyTorch remember the best cuDNN algorithm for the single-group case? I need to run many tests of this kind of conv2d with a large kernel, but each time I only need a single run.

I can sort of achieve this by using a duplicated kernel with two groups, but that does not seem elegant. Thank you!

The code I used for the above test (I did two separate runs, for n_frames=1 and n_frames=2):

import torch
from time import time

torch.backends.cudnn.benchmark = True

device = 'cuda:1'
torch.manual_seed(0)

input_field = torch.rand(1, 1, 1080, 1920, dtype=torch.float32, device=device)
kernel = torch.rand(1, 1, 101, 101, dtype=torch.float32, device=device)


n_frames = 1
input_field = input_field.repeat(1, n_frames, 1, 1)  # Repeat the input field for n_frames
kernel = kernel.repeat(n_frames, 1, 1, 1)  # Repeat the sub-hologram phase for n_frames



for i in range(10):
    torch.cuda.synchronize(device=device)
    start_time = time()
    ret = torch.nn.functional.conv2d(
            input_field,
            kernel,
            padding='same',
            groups=n_frames,
        )
    torch.cuda.synchronize(device=device)
    end_time = time()
    print(f"Time taken for conv2d with groups={n_frames}, round={i}: {end_time - start_time:.4f} seconds")

Given these results, it seems that most of the overhead is due to benchmarking for the best kernel (as far as I know, PyTorch doesn't support caching the selected kernel across sessions), which leaves you with two options.

If you're willing to give up the best kernel pick, I would suggest setting torch.backends.cudnn.benchmark = False.
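If only some calls should skip the search, the flag can be switched off around a single call and restored afterwards. This is a minimal sketch; `conv2d_no_autotune` is a hypothetical helper. With `benchmark = False`, cuDNN picks an algorithm heuristically without timing candidates, so there is no long first call, but the chosen algorithm may be slower, as the original single-group timings suggest.

```python
import torch
import torch.nn.functional as F


def conv2d_no_autotune(input_field, kernel, groups=1):
    """Run one conv2d with cuDNN autotuning disabled, then restore the old flag."""
    previous = torch.backends.cudnn.benchmark
    # No timing search: cuDNN's heuristic algorithm choice is used instead.
    torch.backends.cudnn.benchmark = False
    try:
        return F.conv2d(input_field, kernel, padding="same", groups=groups)
    finally:
        # Restore the global setting even if the convolution raises.
        torch.backends.cudnn.benchmark = previous
```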

Or, with torch.backends.cudnn.benchmark = True, you can run all your tests in one session:

  • run once to pick the best kernel and warm up the GPU
  • run each test on its own in an interactive environment (keep in mind that the best kernel will persist only as long as you're using the same kernel shape, dtype, …)

You can do this either in a notebook or by starting an interactive Python session as the last thing in your script: import IPython; IPython.embed()
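The session-based approach might look like the sketch below. `warm_up` and `run_test` are hypothetical helpers, not PyTorch APIs. The warm-up call uses zero-filled tensors on the assumption that cuDNN's algorithm cache is keyed on shapes, dtype, layout and convolution parameters rather than on tensor values. Later calls with matching settings should then reuse the cached choice instead of repeating the long search.

```python
import torch
import torch.nn.functional as F

# Enable autotuning once for the whole session.
torch.backends.cudnn.benchmark = True


def warm_up(shape_in, shape_kernel, groups=1, device="cpu", dtype=torch.float32):
    """Trigger algorithm selection once for this shape/dtype/device combination."""
    x = torch.zeros(shape_in, device=device, dtype=dtype)
    w = torch.zeros(shape_kernel, device=device, dtype=dtype)
    F.conv2d(x, w, padding="same", groups=groups)
    if x.is_cuda:
        torch.cuda.synchronize(x.device)


def run_test(input_field, kernel, groups=1):
    """One test run; reuses the tuned algorithm if settings match the warm-up."""
    return F.conv2d(input_field, kernel, padding="same", groups=groups)
```

In practice, you would call `warm_up((1, 1, 1080, 1920), (1, 1, 101, 101), device="cuda:1")` once at the end of the script, then run `import IPython; IPython.embed()` and call `run_test(...)` as often as needed from the interactive prompt.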

Thanks for the help! I will go with the interactive option
