Inconsistent performance with large kernel convolutions between computers

I'm having trouble with large convolution kernels showing extremely different performance between environments and computers.

I'm using the following benchmark, which applies a large Gaussian kernel to a large random array.

import torch
import torch.nn.functional as F
import numpy as np
import time

def gaussian_kernel(size, sigma):
    # Build a normalized 2D Gaussian kernel of shape (size, size).
    half = int(size) // 2
    x, y = np.mgrid[-half:half + 1, -half:half + 1]
    g = np.exp(-(x**2 + y**2) / (2 * sigma**2))
    return g / g.sum()

def conv_helper(in_arr, conv_kernel, device, padding):
    start_time = time.time()

    # Convert input array to PyTorch tensor and move to device
    torch_in_arr = torch.from_numpy(in_arr).float().unsqueeze(0).unsqueeze(0)
    end_time = time.time()
    print(f"Time taken to convert input array to tensor: {end_time - start_time:.6f} seconds")

    start_time = time.time()
    torch_in_arr = torch_in_arr.to(device)
    end_time = time.time()
    print(f"Time taken to move input tensor to device: {end_time - start_time:.6f} seconds")

    start_time = time.time()
    # Convert kernel to PyTorch tensor and move to device
    torch_kernel = torch.from_numpy(conv_kernel).float().unsqueeze(0).unsqueeze(0)
    end_time = time.time()
    print(f"Time taken to convert kernel to tensor: {end_time - start_time:.6f} seconds")

    start_time = time.time()
    torch_kernel = torch_kernel.to(device)
    end_time = time.time()
    print(f"Time taken to move kernel to device: {end_time - start_time:.6f} seconds")

    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.time()
    # Perform the convolution on the selected device
    out_arr_gpu = F.conv2d(torch_in_arr, torch_kernel, padding=padding)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    end_time = time.time()
    print(f"Time taken for convolution operation: {end_time - start_time:.6f} seconds")

    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.time()
    # Move result back to the CPU asynchronously
    out_arr_gpu = out_arr_gpu.to('cpu', non_blocking=True)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # Ensure all GPU operations have completed
    end_time = time.time()
    print(f"Time taken to move result back to CPU: {end_time - start_time:.6f} seconds")

    start_time = time.time()
    # Convert to numpy array
    out_arr_cpu = out_arr_gpu.numpy().squeeze()
    end_time = time.time()
    print(f"Time taken to convert to numpy array: {end_time - start_time:.6f} seconds")

    return out_arr_cpu

# Example usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_size = 4000
kernel_size = 151
sigma = 25

# Create a large random input array
input_array = np.random.randn(input_size, input_size)

# Create a Gaussian kernel
gaussian_k = gaussian_kernel(kernel_size, sigma)

# Perform convolution with different paddings
padding_same = kernel_size // 2  # Same padding

print("Running with same padding:")
output_array_same = conv_helper(input_array, gaussian_k, device, padding_same)
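
For reference, the version details for each setup below can be read straight from PyTorch; this is only a minimal sketch using standard attributes, nothing specific to my setup:

import torch

# Quick environment report for comparing the different setups.
print("torch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("capability:", torch.cuda.get_device_capability(0))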

I have tested this in multiple environments:

  1. PC1: 1x RTX 3090, PyTorch 2.1.2, CUDA 12.2, driver 535…
  2. PC1, but in the NVCR PyTorch container 22.05 (torch 1.12)
  3. PC2: 2x RTX 4090, PyTorch 2.1.2, CUDA 12.2, driver 535…
  4. PC2, but in the NVCR PyTorch container 22.05 (torch 1.12)

Setups 1, 2, and 4 all get a total runtime of about 2 seconds, with the convolution operation taking less than 1 second. Setup 3, however, has the convolution operation take about 20 seconds.

Are there any further ways I could narrow down why setup 3 is so much slower?

I have also tried different PyTorch environments, manually pinning the GPU/process to a NUMA node (both setups are AMD Threadrippers, a 3960X and a 7960X respectively), disabling NCCL_P2P…
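
The NUMA/NCCL tweaks were roughly along these lines (a sketch only; the core range is a placeholder for one NUMA node, and the affinity call is Linux-only):

import os

# Disable peer-to-peer copies and restrict the process to the cores of one
# NUMA node. The core range below is a placeholder, not the actual topology.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.sched_setaffinity(0, range(0, 24))

import torch  # imported only after the environment is set up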

It seems like it must be some CUDA or torch system variable, as the performance is as expected in the Docker container.

Cheers.

More testing.

PC1 has a conv time of ~0.5 s with the NVCR PyTorch container 24.01
(CUDA 12.3.2, torch 2.2.0a0+81ea7a4, TensorRT 8.6.1.6).

But PC2 has the same 20-second conv time with that same container.

Furthermore, when using torch versions < 2.0 (tested multiple containers on both PCs), the two machines have the same performance.

You could profile the workload via e.g. Nsight Systems to narrow down which operations are causing the slowdown. Assuming you are right and it’s a conv, you could check the cuDNN version and update it to the latest one.
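
As a complement to Nsight Systems, torch.profiler can show which CUDA kernels the conv dispatches to; a self-contained sketch with the sizes from your benchmark:

import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1, 1, 4000, 4000, device="cuda")
w = torch.randn(1, 1, 151, 151, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    F.conv2d(x, w, padding=75)
    torch.cuda.synchronize()

# The kernel names in this table reveal which cuDNN implementation was chosen.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))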

I will try both, thanks.

Can the cuDNN version on the host affect torch within a container?

Only if you mount it into the container to replace the libraries shipped inside it, in which case PyTorch would use the host version. By default, Docker should not be able to use your host cuDNN.
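
If you want to verify which cuDNN is actually loaded inside (or outside) the container, you could inspect the libraries mapped into the process; a Linux-only sketch:

import torch
import torch.nn.functional as F

# Run a small conv first so cuDNN is actually loaded.
x = torch.randn(1, 1, 64, 64, device="cuda")
w = torch.randn(1, 1, 3, 3, device="cuda")
F.conv2d(x, w, padding=1)

print("cuDNN version reported by torch:", torch.backends.cudnn.version())

# /proc/self/maps lists every shared library mapped into this process,
# so this shows which libcudnn files are really being used.
with open("/proc/self/maps") as f:
    libs = {line.split()[-1] for line in f if "cudnn" in line}
for lib in sorted(libs):
    print(lib)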

Here are two screenshots from Nsight Systems.


These are from PC1, which is performing much better.


Additionally, reinstalling cuDNN made no difference.

I'm getting the same behaviour on a fresh install of Ubuntu 22.04, but I think I may have figured out what's going on.

Interestingly, looking at Nsight Systems and running the same conv operation 3x with torch.backends.cudnn.benchmark=True: the kernel it picks on the first system, and after optimizing with benchmark, is ‘conv2d_grouped_direct_kernel’, while PC2 by default selects an implementation that is much slower for this kernel/matrix size. The question now is: is there a way to force the use of the faster implementation under the hood?


For static shapes you could use torch.backends.cudnn.benchmark = True to allow cuDNN to profile kernels before selecting them. There is no way to manually select algorithms.
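
As a rough sketch with the sizes from this thread, enabling benchmark mode and warming up once before timing would look like this (the first call pays for the algorithm search):

import time
import torch
import torch.nn.functional as F

torch.backends.cudnn.benchmark = True  # let cuDNN autotune per input shape

x = torch.randn(1, 1, 4000, 4000, device="cuda")
w = torch.randn(1, 1, 151, 151, device="cuda")

# Warm-up call: this includes the cuDNN algorithm search.
F.conv2d(x, w, padding=75)
torch.cuda.synchronize()

start = time.time()
F.conv2d(x, w, padding=75)
torch.cuda.synchronize()
print(f"steady-state conv time: {time.time() - start:.3f} s")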