Torch.cuda.amp inference slower than normal

I am trying to run inference with a standard resnet18 model from torchvision.models. The model is trained without any mixed precision, purely in FP32.
However, I want faster inference, so I enabled torch.cuda.amp.autocast() only while running a test inference case.

The code is given below -

model = torchvision.models.resnet18()
model = model.to(device) # Pushing to GPU

# Train the model normally

Without amp -

tensor = torch.rand(1,3,32,32).to(device) # Random tensor for testing
with torch.no_grad():
  model.eval()
  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
  model(tensor) # warmup
  model(tensor) # warmup
  start.record()
  for i in range(20): # total time over 20 iterations 
    model(tensor)
  end.record()
  torch.cuda.synchronize()
    
  print('execution time in milliseconds: {}'.format(start.elapsed_time(end)/20))

  execution time in milliseconds: 5.264944076538086

With amp -

tensor = torch.rand(1,3,32,32).to(device)
with torch.no_grad():
  model.eval()
  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
  model(tensor) # warmup
  model(tensor) # warmup

  start.record()
  with torch.cuda.amp.autocast(): # autocast enabled for the timed loop
    for i in range(20):
      model(tensor)
  end.record()
  torch.cuda.synchronize()
  
  print('execution time in milliseconds: {}'.format(start.elapsed_time(end)/20))

  execution time in milliseconds: 10.619884490966797

Clearly, the autocast()-enabled code is taking roughly double the time. Even with larger models like resnet50, the relative timing difference is approximately the same.
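
One detail I am not sure about: the warmup calls in the amp version run outside the autocast region, so the first timed iteration may include one-time casting and kernel-selection overhead. A variant that also warms up inside autocast would look roughly like this (untested sketch, reusing model and tensor from above) -

with torch.no_grad():
  model.eval()
  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
  with torch.cuda.amp.autocast():
    model(tensor) # warmup inside autocast
    model(tensor) # warmup inside autocast
    torch.cuda.synchronize()
    start.record()
    for i in range(20):
      model(tensor)
    end.record()
  torch.cuda.synchronize()

  print('execution time in milliseconds: {}'.format(start.elapsed_time(end)/20))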

Can someone help me out with this? I am running this example on Google Colab, and below are the specifications of the GPU:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
torch.version.cuda == 10.1
torch.__version__  == 1.8.1+cu101

The P100 doesn’t have TensorCores, so while I wouldn’t expect a slowdown (this seems to be bad), I also wouldn’t expect to see a huge increase in performance.
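
If you want to verify what the GPU supports, a quick check would be something like this (rough sketch, assuming torch is imported and a CUDA device is visible) -

major, minor = torch.cuda.get_device_capability(0)
print(torch.cuda.get_device_name(0), 'compute capability: {}.{}'.format(major, minor))
# P100 is compute capability 6.0 (no Tensor Cores); V100 (7.0) and T4 (7.5) do have them
if major >= 7:
    print('FP16 Tensor Cores available - autocast can speed up math-bound layers')
else:
    print('no Tensor Cores - autocast mostly reduces memory traffic, not compute time')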

UPDATE: I executed the same code as above, but on a different GPU, a Tesla T4 (which has around 320 Tensor Cores). There is an improvement in the execution time, both with and without amp.

Without amp -

execution time in milliseconds: 3.9147518157958983

With amp -

execution time in milliseconds: 3.4673088073730467

The execution time with autocasting is slightly better than the one without. However, the time difference is not as great as expected (at least a 2x speedup would have been preferable).

What can be the reason for this?
Is there any bug in the code?
Is the GPU ineffective?
Is the resnet18 model too small and simple to show any significant difference in execution time?
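
For the last question, this is the kind of heavier test I have in mind (just a sketch with a larger batch and 224x224 inputs, reusing model and device from above; I have not run it yet) -

big = torch.rand(64, 3, 224, 224, device=device) # much heavier workload than 1x3x32x32

def bench(use_amp, iters=20):
  # compare FP32 vs autocast on an input size that should be compute-bound
  with torch.no_grad():
    model.eval()
    with torch.cuda.amp.autocast(enabled=use_amp):
      for _ in range(5): # warmup
        model(big)
      torch.cuda.synchronize()
      start = torch.cuda.Event(enable_timing=True)
      end = torch.cuda.Event(enable_timing=True)
      start.record()
      for _ in range(iters):
        model(big)
      end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print('fp32     : {} ms/iter'.format(bench(use_amp=False)))
print('autocast : {} ms/iter'.format(bench(use_amp=True)))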

Reminder

I’d be grateful if someone could answer the questions asked in the previous post!

It would be great if someone could provide an answer to this question.

I tested autocast in my code and also observed that it takes much longer to run.

import torch
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_index, (data, targets) in enumerate(loop):
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass under autocast (device_type must be a string, e.g. "cuda")
        with torch.autocast(device.type, dtype=torch.bfloat16, cache_enabled=True):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward
        optimizer.zero_grad()
        if device.type == "cuda":
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Update tqdm loop
        loop.set_postfix(loss=loss.item())
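
For context, the scaler and the call look roughly like this (model, optimizer, loss_fn, train_loader, and num_epochs are placeholders from the rest of my script) -

# the scaler only matters on CUDA; with bfloat16 autocast it is effectively optional,
# since bf16 keeps FP32's dynamic range
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

for epoch in range(num_epochs):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)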

Could you provide some more details about the conditions required to reproduce this? (e.g., CPU, GPU, model definition that reproduces this?)

If you are using a convolutional architecture on CUDA, fp16 autocast may be faster as bfloat16 support is still fresh and may not be available on your installation. (Try setting the environment variable TORCH_CUDNN_V8_API_ENABLED=1 if you want to check this: [cuDNN v8] Extend current cuDNN convolution v8 API binding to support BFloat16 · Issue #58861 · pytorch/pytorch (github.com))
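
For example, the forward pass above could request float16 explicitly (just a sketch, reusing the names from your snippet) -

# try float16 instead of bfloat16 in the autocast region
with torch.autocast("cuda", dtype=torch.float16):
    predictions = model(data)
    loss = loss_fn(predictions, targets)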

Additionally, AMP can speed up computation/math-bottlenecked models, but it isn’t expected to speed up memory bandwidth-bottlenecked models (e.g., those that have many pointwise operations or operations with low arithmetic intensity) as it effectively trades lower precision computation for additional pointwise operations.
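
A toy illustration of that distinction (rough sketch; exact numbers depend on the GPU): a large matmul is math-bound and can benefit from autocast, while an elementwise op is bandwidth-bound and generally does not -

import torch

device = torch.device("cuda")
a = torch.randn(4096, 4096, device=device)
b = torch.randn(4096, 4096, device=device)

def time_ms(fn, iters=20):
    # warm up, then time with CUDA events
    for _ in range(3):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

with torch.no_grad():
    print("matmul fp32       :", time_ms(lambda: a @ b))
    with torch.cuda.amp.autocast():
        print("matmul autocast   :", time_ms(lambda: a @ b))  # math-bound; cast to fp16, can use Tensor Cores
    print("pointwise fp32    :", time_ms(lambda: torch.relu(a) + b))
    with torch.cuda.amp.autocast():
        print("pointwise autocast:", time_ms(lambda: torch.relu(a) + b))  # bandwidth-bound; little to no gain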

Dear @eqy,

Probably there was something wrong with my network or some configuration.
In the end, I observed a speedup of about 3x with autocasting.

Thanks for your support.
