Forward pass gets 10000x slower after iterating for a while

I implemented a simple deconvolution (transposed convolution) network, closely following PyTorch's official DCGAN tutorial.
I repeatedly pass a zero tensor through it. After a while, the forward pass suddenly becomes dramatically slower. I am wondering what the reason is and how I can resolve it.

Code:

import torch
import torch.nn as nn
import time

# JUST TO MEASURE TIME
class Timer:
    def __init__(self, msg):
        self.msg = msg
        
    def __enter__(self):
        self.start = time.process_time()
        return self

    def __exit__(self, *args):
        self.end = time.process_time()
        self.interval = self.end - self.start
        
        print('{}: {:.5f}'.format(self.msg, self.interval))
        
device = torch.device("cuda")

ngf, nc, nz, batchSize = 64, 1, 6, 1<<16  # feature maps, output channels, latent dim, 65536 samples per batch
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 8 x 8
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 16 x 16
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 32 x 32
        )

    def forward(self, input):
        return self.main(input)
    
# Create the generator
netG = Generator().to(device)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)

# torch.backends.cudnn.benchmark=True

while True:
    with Timer('Time elapsed'):
        with torch.no_grad():
            netG(torch.zeros([batchSize, nz, 1, 1], device=device))

Result:

Time elapsed: 0.02309
Time elapsed: 0.00072
Time elapsed: 0.00208
Time elapsed: 0.00128
Time elapsed: 0.00119
Time elapsed: 0.00153
Time elapsed: 0.00176
Time elapsed: 0.00170
Time elapsed: 0.00185
Time elapsed: 0.00188
Time elapsed: 0.00191
Time elapsed: 0.00190
Time elapsed: 0.00171
Time elapsed: 0.00176
Time elapsed: 0.00167
Time elapsed: 0.00120
Time elapsed: 0.00168
Time elapsed: 0.00169
Time elapsed: 0.00166
Time elapsed: 0.00167
Time elapsed: 0.00171
Time elapsed: 0.00168
Time elapsed: 0.00168
Time elapsed: 0.00168
Time elapsed: 0.00169
Time elapsed: 0.00177
Time elapsed: 0.00173
Time elapsed: 0.00176
Time elapsed: 0.00173
Time elapsed: 0.00171
Time elapsed: 0.00168
Time elapsed: 0.00173
Time elapsed: 0.00168
Time elapsed: 0.00178
Time elapsed: 0.00169
Time elapsed: 0.00171
Time elapsed: 0.00168
Time elapsed: 0.00169
Time elapsed: 0.00169
Time elapsed: 0.00173
Time elapsed: 0.00154
Time elapsed: 0.00170
Time elapsed: 0.00167
Time elapsed: 0.00224
Time elapsed: 0.00117
Time elapsed: 0.00175
Time elapsed: 0.00168
Time elapsed: 0.00173
Time elapsed: 0.00169
Time elapsed: 12.62760
Time elapsed: 12.71425
Time elapsed: 12.71379
Time elapsed: 12.71846
Time elapsed: 12.71909
Time elapsed: 12.71898
Time elapsed: 12.72288
Time elapsed: 12.72157
Time elapsed: 12.72226
Time elapsed: 12.72456
Time elapsed: 12.72350
Time elapsed: 12.72480
Time elapsed: 12.72644
Time elapsed: 12.72337
Time elapsed: 12.72424
Time elapsed: 12.72538
Time elapsed: 12.72533
Time elapsed: 12.72510
Time elapsed: 12.72507
Time elapsed: 12.72806
Time elapsed: 12.72865
Time elapsed: 12.72764
Time elapsed: 12.72431

My GPU: Titan RTX
PyTorch version: 1.4

This is especially a problem when I use a very large batch size to make full use of GPU memory.

CUDA calls are asynchronous, so your current Timer class will most likely only measure the Python overhead and the kernel launches.
To properly time the kernels, you would have to synchronize before starting and before stopping the timer via torch.cuda.synchronize().
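
As a rough sketch, a Timer that synchronizes before taking its readings could look like this; the name SyncTimer and the switch to time.perf_counter() are my own choices for illustration, not from your code:

import time
import torch

class SyncTimer:
    def __init__(self, msg):
        self.msg = msg

    def __enter__(self):
        torch.cuda.synchronize()           # wait for all previously launched GPU work
        self.start = time.perf_counter()   # wall-clock time, so time spent waiting on the GPU is counted
        return self

    def __exit__(self, *args):
        torch.cuda.synchronize()           # wait for the kernels launched inside the block to finish
        self.end = time.perf_counter()
        self.interval = self.end - self.start
        print('{}: {:.5f}'.format(self.msg, self.interval))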

Your point is correct, but regardless, the overall code was previously taking unusually long.
After I wrapped my code with torch.cuda.synchronize(), things went back to normal. I assume the problem was that a long queue of these asynchronous calls was building up and ruining everything. Am I correct?
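
Roughly (a sketch of one way to wrap it, with the synchronize call inside the timed region):

while True:
    with Timer('Time elapsed'):
        with torch.no_grad():
            netG(torch.zeros([batchSize, nz, 1, 1], device=device))
        torch.cuda.synchronize()  # block until the forward pass has actually finished on the GPU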

Now each batch takes 0.25 seconds; is that normal?

That seems unlikely, since you are passing 65k samples at once.
I see a forward time of ~12.5s per iteration on a V100.
