How does convolution scale so efficiently with batches?

I believe that conv2d is implemented with some cudnn tricks, but how does it scale so well with the batch size?

When I run a batch size of 1:

import time

import torch
import torch.nn.functional as F

bs = 1
h = 128
w = 256
in_channels = 512
out_channels = 1024
kernel_size = (3,3)
stride = (1,1)
padding = (1,1)
dilation = (1,1)
cuda = True

# `Test` is an nn.Module wrapper around the convolution (its definition is
# omitted here); it is only moved to the GPU below, not used in the timing.
test = Test(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation
)

x = torch.ones(bs, in_channels, h, w).float()
weights = torch.ones(out_channels, in_channels, kernel_size[0], kernel_size[1])
bias = 100*torch.ones(out_channels)

print("Input:")
print(x.shape)

if cuda:
    test = test.cuda()
    x = x.cuda()
    weights = weights.cuda()
    bias = bias.cuda()

s = time.time()
pytorch = F.conv2d(x, 
    weights, 
    bias,
    stride=stride,
    padding=padding,
    dilation=dilation)
pytorch_time = time.time() - s

print('PyTorch:')
print(pytorch.shape)
print('PyTorch Time: ', pytorch_time)

I get:

>> Input:
>> torch.Size([1, 512, 128, 256])
>> PyTorch:
>> torch.Size([1, 1024, 128, 256])
>> PyTorch Time:  0.004613637924194336

Running the same code above with a batch size of 5 gives:

>> Input:
>> torch.Size([5, 512, 128, 256])
>> PyTorch:
>> torch.Size([5, 1024, 128, 256])
>> PyTorch Time:  0.0049724578857421875

The wall-clock time is negligibly different. Looking at the code in the THCUNN library (granted, it’s not cudnn), it appears that the batches are processed in a for-loop (see here). This is the same way Caffe handles it too.

I can’t find any code suggesting that the samples in a batch are processed together (im2col + gemm). Everything I find seems to suggest each sample is handled sequentially. Given this, why doesn’t the running time scale linearly with the batch size?
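
To make that concrete, here is a rough sketch of what I mean by a per-sample im2col + gemm loop, written in pure PyTorch with F.unfold standing in for im2col (just an illustration of the approach, not the actual THCUNN/Caffe code):

def conv2d_per_sample(x, weights, bias, stride, padding, dilation):
    # Loop over the batch: im2col (F.unfold) followed by a GEMM for each sample
    n, c_in, h, w = x.shape
    c_out, _, kh, kw = weights.shape
    h_out = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // stride[0] + 1
    w_out = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // stride[1] + 1
    w_mat = weights.view(c_out, -1)                      # (C_out, C_in*kh*kw)
    out = []
    for i in range(n):
        cols = F.unfold(x[i:i + 1], (kh, kw), dilation=dilation,
                        padding=padding, stride=stride)  # (1, C_in*kh*kw, L)
        y = w_mat @ cols[0] + bias[:, None]              # GEMM: (C_out, L)
        out.append(y.view(c_out, h_out, w_out))
    return torch.stack(out)

The result should match F.conv2d up to floating-point error, but each sample goes through its own im2col and gemm, which is why I expected roughly linear scaling.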

The thing is that the CUDA API is asynchronous, so the only thing you measure here is the time to launch the job, not how long it takes to run on the GPU. You need to synchronize:

# Make sure nothing was still running on the GPU
torch.cuda.synchronize()
s = time.time()
pytorch = F.conv2d(x, 
    weights, 
    bias,
    stride=stride,
    padding=padding,
    dilation=dilation)
# Wait for the job to actually be done
torch.cuda.synchronize()
pytorch_time = time.time() - s
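
Alternatively, CUDA events record timestamps on the GPU itself, which keeps the Python-side launch overhead out of the measurement; a minimal sketch using the same tensors as above:

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
pytorch = F.conv2d(x,
    weights,
    bias,
    stride=stride,
    padding=padding,
    dilation=dilation)
end.record()

# Block until the events have been recorded, then read the GPU-side time
torch.cuda.synchronize()
print('PyTorch Time (ms): ', start.elapsed_time(end))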

Ahh, thank you, that makes sense.

Now the real question is: how does PyTorch launch the jobs so efficiently? For the learning experience, I’ve re-implemented convolution with my own version of im2col. Using mine, the kernel launch time is an order of magnitude slower, but the actual running time (after synchronizing) is only 1.25x slower. Is there some funky stuff going on under the hood to launch the job?

There is nothing special going on: the whole convolution goes directly to C++ and then launches a single cudnn kernel.
Your version with im2col most certainly launches a few kernels, which results in an increased launch time.
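
If you want to see exactly which kernels each version launches, the profiler can list them (this assumes a reasonably recent PyTorch with torch.profiler; older versions have torch.autograd.profiler instead):

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation)

# One row per op/kernel, with launch counts and GPU time
print(prof.key_averages().table(sort_by="cuda_time_total"))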

Is there somewhere I can find more detail about this? You’re right about launching a few kernels. Each batch launches a kernel for the im2col computation and a kernel for the gemm (done through cuBLAS). However, when I profile the loop, I see that most of the launch time is taken up by the initial im2col call (i.e. in the first batch). Synchronizing with cudaDeviceSynchronize() after each call produces these times:

batch: 0
im2col: 0.000030s
gemm: 0.000026s
batch: 1
im2col: 0.000012s
gemm: 0.000008s
batch: 2
im2col: 0.000008s
gemm: 0.000007s
batch: 3
im2col: 0.000008s
gemm: 0.000007s
batch: 4
im2col: 0.000008s
gemm: 0.000007s
batch: 5
im2col: 0.000008s
gemm: 0.000007s

I’d presumed the speed-up for the later batches was due to caching the columns (for im2col) and the weights (for gemm), but I don’t know how I’d do this with a single kernel. Or is this something that cudnn makes possible? As in, does cudnnConvolutionForward implement the entire operation in a single kernel?

The very first batch is always slower, as many things are lazily initialized in CUDA, so the first call will potentially do extra initialization work. You should be able to ignore these timings for any real workload.
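
In practice that means doing a few warm-up iterations before the timed region, for example:

# Warm up so lazy CUDA/cudnn initialization doesn't land in the measurement
for _ in range(10):
    F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation)
torch.cuda.synchronize()

s = time.time()
pytorch = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation)
torch.cuda.synchronize()
pytorch_time = time.time() - s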

I’m not sure about the cudnn internals; you should be able to use nvvp to check. But yes, I think it does everything in a single kernel (though that may vary depending on which algorithm was selected).

I think you’re right. I tried my implementation with varying kernel sizes and strides. With a standard (3,3) kernel and unit strides, I’m about 30% off the speed of the cudnn convolution used by PyTorch; however, for non-square kernels and non-unit strides I’m roughly equal. In general, though, square kernels with unit strides seem unbeatable.

I think it heavily depends on the algorithm. My guess is that the unit-stride (3,3) convolution is particularly optimized, as it’s simple and very common. I think for that operation an FFT convolution is faster than im2col, so the selector is likely choosing that algorithm, which I’m pretty sure can be implemented in a single kernel.
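
If you want cudnn to actually benchmark its available algorithms for your input shapes, rather than rely on its heuristic, you can enable benchmark mode (a global flag; the first call for each new input shape is slower while the candidates are tried):

import torch.backends.cudnn as cudnn

# Try the available cudnn algorithms per input shape and cache the fastest one
cudnn.benchmark = True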

Thanks for the help making sense of this!