Variable running time for conv2d

Hello.

For teaching purposes, I have designed a small network that uses conv2d to compute a classical Gabor descriptor for a toy QBE system. One of the goals is to show the speedup obtained with GPUs. I have inserted time measurements and, on the GPU, the processing time of the first batch (about 250 ms) is much longer than that of the subsequent ones (about 850 µs), which I assume is due to allocations and initializations being performed. The second batch is also a bit slower (about 1250 µs); then, for about 50 iterations, the batch run time is stable at about 850 µs. That would correspond to about 80 Tflops if conv2d were implemented using the direct convolution formula, which cannot be the case, since the GTX 1080 Ti used has a peak capability of only about 11 Tflops.

What is strange and disappointing is that, after 58 batches, the batch processing time jumps to about 35 ms, which is about 40 times longer. I checked for memory leaks but there seem to be none.

Any idea why such a slowdown appears and how to avoid it?
The code and the output log are given below.

Code:

import torch
import torchvision

import matplotlib.pyplot as plt
import numpy as np
import math

import torch.nn as nn
import torch.nn.functional as F

import time

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = "cpu"
print(device)

def gabor_kernel_init_seq(weight,lambd = 16.0, nt = 8, n = 0, sl = 0.7, st = 1.4, nl = 4.0):
    t0 = time.time()
    if n <= 0: n = 1+2*int(nl*lambd)
    gl = -0.5/(sl*sl)
    gt = -0.5/(st*st)
    for t in range (0, nt):
        theta = t*math.pi/nt
        c = math.cos(theta)/lambd
        s = math.sin(theta)/lambd
        x0 = 0.5*(n-1)*(c+s)
        y0 = 0.5*(n-1)*(c-s)
        sc = 1.0/(2*math.pi*sl*st*lambd*lambd)
        for y in range (0,n):
            for x in range (0,n):
                xr = c*x+s*y-x0  # centering, rotation and scaling
                yr = c*y-s*x-y0  # centering, rotation and scaling
                a = 2.0*math.pi*xr  # wave phase
                g = sc*math.exp(gl*xr*xr+gt*yr*yr)  # Gaussian amplitude
                weight[t+0*nt, 0, y, x] = g*math.cos(a)
                weight[t+1*nt, 0, y, x] = g*math.sin(a)
    print(time.time()-t0)

def gabor_kernel_init(weight,lambd = 16.0, nt = 8, n = 0, sl = 0.8, st = 1.6, nl = 4.0):
    t0 = time.time()
    if n <= 0: n = 1+2*int(nl*lambd)
    gl = -0.5/(sl*sl)
    gt = -0.5/(st*st)
    x = torch.tensor(range(n)).unsqueeze_(0).expand(n,n)  # x coordinate
    y = torch.tensor(range(n)).unsqueeze_(1).expand(n,n)  # y coordinate
    for t in range (0, nt):
        theta = t*math.pi/nt
        c = math.cos(theta)/lambd
        s = math.sin(theta)/lambd
        x0 = 0.5*(n-1)*(c+s)
        y0 = 0.5*(n-1)*(c-s)
        sc = 1.0/(2*math.pi*sl*st*lambd*lambd)
        xr = c*x+s*y-x0  # centering, rotation and scaling
        yr = c*y-s*x-y0  # centering, rotation and scaling
        a = 2.0*math.pi*xr  # wave phase
        g = sc*torch.exp(gl*xr*xr+gt*yr*yr)  # Gaussian amplitude
        weight[t+0*nt, 0] = g*torch.cos(a)
        weight[t+1*nt, 0] = g*torch.sin(a)
    print("kernel init %dx%d: %.2f ms"% (n, n, 1000*(time.time()-t0)))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv0 = nn.Conv2d(1, 16, 25, bias=False)
        self.conv1 = nn.Conv2d(1, 16, 49, bias=False)
        self.conv2 = nn.Conv2d(1, 16, 97, bias=False)
        self.conv3 = nn.Conv2d(1, 16, 193, bias=False)
        gabor_kernel_init(self.conv0.weight,lambd = 3.0)
        gabor_kernel_init(self.conv1.weight,lambd = 6.0)
        gabor_kernel_init(self.conv2.weight,lambd = 12.0)
        gabor_kernel_init(self.conv3.weight,lambd = 24.0)
 
    def forward(self, x):
        x0 = torch.squeeze(F.adaptive_avg_pool2d(self.conv0(x)**2, (1, 1)),3)/1
        x1 = torch.squeeze(F.adaptive_avg_pool2d(self.conv1(x)**2, (1, 1)),3)/2
        x2 = torch.squeeze(F.adaptive_avg_pool2d(self.conv2(x)**2, (1, 1)),3)/4
        x3 = torch.squeeze(F.adaptive_avg_pool2d(self.conv3(x)**2, (1, 1)),3)/8
        y = torch.cat((x0, x1, x2, x3), dim = 2)
        y1, y2 = torch.split(y, 8, dim = 1)
        return y1+y2

net = Net().to(device)

from PIL import Image
import torchvision.transforms.functional as TF
x = torch.cat((255*TF.to_tensor(Image.open('cat.jpg')).to(device).unsqueeze_(0),
               255*TF.to_tensor(Image.open('houses.jpg')).to(device).unsqueeze_(0),
               255*TF.to_tensor(Image.open('bur0.jpg')).to(device).unsqueeze_(0),
               255*TF.to_tensor(Image.open('bur1.jpg')).to(device).unsqueeze_(0)))
print(x.shape)

nops = ((64.0*64*193*193)+(160.0*160*97*97)+(208.0*208*49*49)+(232*232*25*25))*16.0*2.0
print("Gflops/image: %.3f (direct calculation)"% (nops/1000000000.0))

tt = 0
for i in range(100): 
    t0 = time.time()
    y = net(x)
    dt = time.time()-t0
    print("iter %2d: %.2f µs, %.1f Gflops"% (i, dt*1000000, 4*nops/dt/1000000000))
    if (i > 1): tt += dt
    #print(y[3, 7, 3])

#print(y.shape)
dt = tt/98
print("Average: %.2f µs, %.1f Gflops"% (dt*1000000, 4*nops/dt/1000000000))

#check for memory leaks
import gc
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        pass

#print(y)

Log:

cuda:0
kernel init 25x25: 4.76 ms
kernel init 49x49: 1.69 ms
kernel init 97x97: 3.01 ms
kernel init 193x193: 17.64 ms
torch.Size([4, 1, 256, 256])
Gflops/image: 16.991 (direct calculation)
iter  0: 313055.28 µs, 217.1 Gflops
iter  1: 1199.25 µs, 56671.3 Gflops
iter  2: 838.28 µs, 81074.1 Gflops
iter  3: 817.06 µs, 83179.6 Gflops
iter  4: 819.44 µs, 82937.6 Gflops
iter  5: 812.77 µs, 83618.8 Gflops
iter  6: 809.43 µs, 83963.6 Gflops
iter  7: 812.77 µs, 83618.8 Gflops
iter  8: 810.38 µs, 83864.8 Gflops
iter  9: 826.84 µs, 82196.2 Gflops
iter 10: 811.58 µs, 83741.6 Gflops
iter 11: 807.05 µs, 84211.7 Gflops
iter 12: 809.67 µs, 83938.9 Gflops
iter 13: 805.85 µs, 84336.3 Gflops
iter 14: 819.21 µs, 82961.7 Gflops
iter 15: 811.58 µs, 83741.6 Gflops
iter 16: 807.52 µs, 84162.0 Gflops
iter 17: 803.23 µs, 84611.6 Gflops
iter 18: 811.82 µs, 83717.0 Gflops
iter 19: 816.58 µs, 83228.2 Gflops
iter 20: 802.99 µs, 84636.7 Gflops
iter 21: 882.39 µs, 77021.5 Gflops
iter 22: 809.19 µs, 83988.4 Gflops
iter 23: 812.53 µs, 83643.3 Gflops
iter 24: 820.88 µs, 82793.1 Gflops
iter 25: 804.19 µs, 84511.3 Gflops
iter 26: 807.52 µs, 84162.0 Gflops
iter 27: 809.19 µs, 83988.4 Gflops
iter 28: 804.90 µs, 84436.2 Gflops
iter 29: 810.38 µs, 83864.8 Gflops
iter 30: 813.48 µs, 83545.3 Gflops
iter 31: 808.95 µs, 84013.1 Gflops
iter 32: 801.32 µs, 84813.0 Gflops
iter 33: 805.85 µs, 84336.3 Gflops
iter 34: 812.77 µs, 83618.8 Gflops
iter 35: 802.99 µs, 84636.7 Gflops
iter 36: 812.53 µs, 83643.3 Gflops
iter 37: 808.24 µs, 84087.5 Gflops
iter 38: 818.73 µs, 83010.1 Gflops
iter 39: 808.24 µs, 84087.5 Gflops
iter 40: 806.81 µs, 84236.6 Gflops
iter 41: 806.09 µs, 84311.3 Gflops
iter 42: 808.72 µs, 84037.9 Gflops
iter 43: 813.96 µs, 83496.3 Gflops
iter 44: 807.29 µs, 84186.8 Gflops
iter 45: 809.91 µs, 83914.2 Gflops
iter 46: 789.88 µs, 86041.8 Gflops
iter 47: 814.20 µs, 83471.9 Gflops
iter 48: 823.02 µs, 82577.2 Gflops
iter 49: 808.95 µs, 84013.1 Gflops
iter 50: 810.15 µs, 83889.5 Gflops
iter 51: 808.00 µs, 84112.3 Gflops
iter 52: 810.15 µs, 83889.5 Gflops
iter 53: 817.06 µs, 83179.6 Gflops
iter 54: 812.05 µs, 83692.5 Gflops
iter 55: 808.48 µs, 84062.7 Gflops
iter 56: 813.25 µs, 83569.8 Gflops
iter 57: 812.29 µs, 83667.9 Gflops
iter 58: 32801.15 µs, 2072.0 Gflops
iter 59: 39488.08 µs, 1721.1 Gflops
iter 60: 39482.12 µs, 1721.4 Gflops
iter 61: 37758.11 µs, 1800.0 Gflops
iter 62: 36569.83 µs, 1858.4 Gflops
iter 63: 35660.03 µs, 1905.9 Gflops
iter 64: 36850.45 µs, 1844.3 Gflops
iter 65: 38070.92 µs, 1785.2 Gflops
iter 66: 32813.07 µs, 2071.2 Gflops
iter 67: 32968.52 µs, 2061.4 Gflops
iter 68: 32866.95 µs, 2067.8 Gflops
iter 69: 33593.89 µs, 2023.1 Gflops
iter 70: 32798.77 µs, 2072.1 Gflops
iter 71: 31944.99 µs, 2127.5 Gflops
iter 72: 32559.63 µs, 2087.3 Gflops
iter 73: 31848.67 µs, 2133.9 Gflops
iter 74: 32484.05 µs, 2092.2 Gflops
iter 75: 31897.54 µs, 2130.7 Gflops
iter 76: 31959.06 µs, 2126.6 Gflops
iter 77: 32399.65 µs, 2097.6 Gflops
iter 78: 32403.95 µs, 2097.4 Gflops
iter 79: 32442.81 µs, 2094.8 Gflops
iter 80: 32287.12 µs, 2104.9 Gflops
iter 81: 31784.30 µs, 2138.3 Gflops
iter 82: 31718.49 µs, 2142.7 Gflops
iter 83: 31781.20 µs, 2138.5 Gflops
iter 84: 32366.28 µs, 2099.8 Gflops
iter 85: 32500.74 µs, 2091.1 Gflops
iter 86: 32556.06 µs, 2087.6 Gflops
iter 87: 31500.82 µs, 2157.5 Gflops
iter 88: 32114.51 µs, 2116.3 Gflops
iter 89: 32172.92 µs, 2112.4 Gflops
iter 90: 32757.04 µs, 2074.8 Gflops
iter 91: 32593.25 µs, 2085.2 Gflops
iter 92: 31516.55 µs, 2156.4 Gflops
iter 93: 32395.12 µs, 2097.9 Gflops
iter 94: 33056.74 µs, 2055.9 Gflops
iter 95: 33230.78 µs, 2045.2 Gflops
iter 96: 33077.00 µs, 2054.7 Gflops
iter 97: 31499.62 µs, 2157.6 Gflops
iter 98: 31959.77 µs, 2126.5 Gflops
iter 99: 32049.18 µs, 2120.6 Gflops
Average: 14714.98 µs, 4618.6 Gflops
<class 'torch.Tensor'> torch.Size([4, 1, 256, 256])
<class 'torch.Tensor'> torch.Size([4, 8, 4])
<class 'torch.nn.parameter.Parameter'> torch.Size([16, 1, 25, 25])
<class 'torch.nn.parameter.Parameter'> torch.Size([16, 1, 49, 49])
<class 'torch.nn.parameter.Parameter'> torch.Size([16, 1, 97, 97])
<class 'torch.nn.parameter.Parameter'> torch.Size([16, 1, 193, 193])

Hi,

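The GPU API is asynchronous: kernel launches return before the work has finished, so without synchronization the timer mostly measures launch overhead. To get accurate measurements, you have to call torch.cuda.synchronize() before each time.time() call.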
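As a minimal sketch of what that looks like in a timing loop (the small convolution and input sizes here are stand-ins chosen for illustration, not the Gabor network from the question, and the `sync` helper is a hypothetical name):

```python
import time
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def sync():
    # Block until all queued GPU work has finished; does nothing on CPU.
    if device.type == "cuda":
        torch.cuda.synchronize()

# Stand-in workload (illustrative sizes only).
conv = torch.nn.Conv2d(1, 16, 25, bias=False).to(device)
x = torch.rand(4, 1, 64, 64, device=device)

times = []
with torch.no_grad():
    for i in range(5):
        sync()                       # nothing from earlier iterations is still pending
        t0 = time.perf_counter()
        y = conv(x)
        sync()                       # wait until the forward pass has actually completed
        times.append(time.perf_counter() - t0)
        print("iter %d: %.3f ms" % (i, times[-1] * 1e3))
```

With both synchronization points in place, each measurement should cover the execution of the forward pass rather than only the time needed to enqueue it.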
Also, if you have cuDNN, you can set torch.backends.cudnn.benchmark = True to get the best performance (at the cost of the first forward pass being slower).
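A short sketch of enabling that flag, again with a stand-in convolution rather than the network from the question. The flag only has an effect when the model actually runs through cuDNN on a GPU; on CPU it is simply ignored.

```python
import torch

# Ask cuDNN to try several convolution algorithms the first time it sees
# a given input shape and to keep the fastest one for later calls.
torch.backends.cudnn.benchmark = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-in workload (illustrative sizes only).
conv = torch.nn.Conv2d(1, 16, 25, bias=False).to(device)
x = torch.rand(4, 1, 64, 64, device=device)

with torch.no_grad():
    y = conv(x)  # on GPU, the first call for this shape includes the autotuning cost
    y = conv(x)  # later calls with the same shape reuse the selected algorithm
```

This tends to pay off when input shapes stay fixed, as they do here; if the shapes change often, the autotuning is repeated for each new shape.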

Thanks for the reply. With a torch.cuda.synchronize() call before each time.time() call, the displayed running times are indeed much more consistent. The first forward pass is already very slow, though.

It is expected that the first forward pass is slow, because CUDA is initialized lazily. During the first forward pass, a lot of initialization is done that you won’t need in the following runs.
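One way to keep that one-time cost out of the reported numbers is to run a few untimed warm-up iterations first. A minimal sketch, assuming a hypothetical helper named `mean_forward_time` and a stand-in model:

```python
import time
import torch

def mean_forward_time(model, x, warmup=3, iters=10):
    """Average forward time in seconds, excluding `warmup` untimed runs."""
    use_cuda = x.is_cuda
    with torch.no_grad():
        for _ in range(warmup):
            model(x)                  # absorbs lazy initialization and allocations
        if use_cuda:
            torch.cuda.synchronize()  # finish the warm-up before starting the clock
        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        if use_cuda:
            torch.cuda.synchronize()  # wait for all timed iterations to complete
        elapsed = time.perf_counter() - t0
    return elapsed / iters

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-in workload (illustrative sizes only).
model = torch.nn.Conv2d(1, 16, 25, bias=False).to(device)
x = torch.rand(4, 1, 64, 64, device=device)

avg = mean_forward_time(model, x)
print("average forward: %.3f ms" % (avg * 1e3))
```

Synchronizing once around the whole timed loop gives an average per iteration; per-iteration timings would need a synchronization before and after each call.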