Batch Scaling and Input Resolution Issues in PyTorch Convolution Layers

Hi everyone,
I’m running into unexpected batch-scaling and input-resolution behavior while benchmarking a simple two-layer convolutional network in PyTorch, and I’d like help understanding why it happens. Details below:

Issue 1: Non-Linear Batch Scaling
Throughput doesn’t scale linearly as the batch size increases. For example, in the following test at 64×64 resolution, throughput is roughly flat for batch sizes 4 through 32, then jumps noticeably at 64, 128, and especially 256.
Here are the logs for 64×64 resolution:

Batch size: 2, total_images: 10000, total time: 0.6013 s, throughput: 16629.97 images/s  
Batch size: 4, total_images: 10000, total time: 0.3541 s, throughput: 28240.33 images/s  
Batch size: 8, total_images: 10000, total time: 0.3828 s, throughput: 26123.36 images/s  
Batch size: 16, total_images: 10000, total time: 0.3639 s, throughput: 27477.99 images/s  
Batch size: 32, total_images: 10016, total time: 0.3550 s, throughput: 28213.21 images/s  
Batch size: 64, total_images: 10048, total time: 0.2698 s, throughput: 37248.43 images/s  
Batch size: 128, total_images: 10112, total time: 0.2140 s, throughput: 47253.55 images/s  
Batch size: 256, total_images: 10240, total time: 0.1388 s, throughput: 73779.19 images/s  

Is it expected for throughput to be relatively flat for smaller batch sizes and then increase significantly at larger batch sizes?

Issue 2: Throughput Depends on Input Resolution
When I change the input resolution, the scaling behavior changes significantly. Larger image resolutions result in lower throughput and less efficient scaling. Here are the logs for various resolutions:
Image Resolution: 64×64 (same logs as shown under Issue 1 above)

Image Resolution: 128×128

Batch size: 2, total_images: 10000, total time: 2.5050 s, throughput: 3991.97 images/s  
Batch size: 4, total_images: 10000, total time: 1.5800 s, throughput: 6328.93 images/s  
Batch size: 8, total_images: 10000, total time: 1.5112 s, throughput: 6617.18 images/s  
Batch size: 16, total_images: 10000, total time: 1.5851 s, throughput: 6308.77 images/s  
Batch size: 32, total_images: 10016, total time: 1.3791 s, throughput: 7262.69 images/s  
Batch size: 64, total_images: 10048, total time: 1.0614 s, throughput: 9466.93 images/s  
Batch size: 128, total_images: 10112, total time: 0.8542 s, throughput: 11838.45 images/s  
Batch size: 256, total_images: 10240, total time: 0.5507 s, throughput: 18595.67 images/s

Image Resolution: 256×256

Batch size: 2, total_images: 10000, total time: 10.0976 s, throughput: 990.33 images/s  
Batch size: 4, total_images: 10000, total time: 6.4409 s, throughput: 1552.57 images/s  
Batch size: 8, total_images: 10000, total time: 6.2662 s, throughput: 1595.86 images/s  
Batch size: 16, total_images: 10000, total time: 6.3972 s, throughput: 1563.18 images/s  
Batch size: 32, total_images: 10016, total time: 5.5516 s, throughput: 1804.16 images/s  
Batch size: 64, total_images: 10048, total time: 4.2874 s, throughput: 2343.59 images/s  
Batch size: 128, total_images: 10112, total time: 7.4968 s, throughput: 1348.85 images/s  
Batch size: 256, total_images: 10240, total time: 7.5911 s, throughput: 1348.94 images/s

Image Resolution: 512×512

Batch size: 2, total_images: 10000, total time: 28.2943 s, throughput: 353.43 images/s  
Batch size: 4, total_images: 10000, total time: 26.5243 s, throughput: 377.01 images/s  
Batch size: 8, total_images: 10000, total time: 25.8578 s, throughput: 386.73 images/s  
Batch size: 16, total_images: 10000, total time: 25.6787 s, throughput: 389.43 images/s  
Batch size: 32, total_images: 10016, total time: 29.4280 s, throughput: 340.36 images/s  
Batch size: 64, total_images: 10048, total time: 29.6016 s, throughput: 339.44 images/s  
Batch size: 128, total_images: 10112, total time: 29.9043 s, throughput: 338.15 images/s  
Batch size: 256, GPU ran out of memory

As shown above, larger image resolutions both lower the absolute throughput and weaken the batch scaling, and at 512×512 the scaling flattens out almost completely. My understanding was that throughput should improve by a roughly consistent factor with batch size regardless of input resolution. Is this behavior expected?
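
For reference, here is my rough back-of-the-envelope estimate of the per-image compute, counting only the multiply-accumulates (MACs) in the two convolutions and assuming "same" padding, so the real cost will differ somewhat:

def conv_macs(h, w, c_in, c_out, k=3):
    # Each output pixel of each output channel costs c_in * k * k MACs
    return h * w * c_out * c_in * k * k

for dim in (64, 128, 256, 512):
    macs = conv_macs(dim, dim, 3, 64) + conv_macs(dim, dim, 64, 3)
    print(f"{dim}x{dim}: ~{macs / 1e6:.0f} MMACs per image")

By this estimate the per-image work grows with H×W, so each doubling of the resolution roughly quadruples the compute per image.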

Code to Reproduce
Here’s the code I used for testing. This was run on a Google Colab instance with a T4 GPU:

import torch.nn as nn
import torch
import time
import math

class SimpleConvNet(nn.Module):
    def __init__(self, in_filters=3, out_filters=3):
        super().__init__()
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_filters, 64, kernel_size=3, padding=1),
            nn.Conv2d(64, out_filters, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.conv2(x)
        return x

test_model = SimpleConvNet().to(device='cuda')
batches = [2, 4, 8, 16, 32, 64, 128, 256]
image_dims = [int(512/8), int(512/4), int(512/2), int(512/1)]
examples = 10000

for image_dim in image_dims:
    print(f"Image resolution: {image_dim} x {image_dim}")
    for batch in batches:
        loops = math.ceil(examples / batch)
        dummy_input = torch.rand((batch, 3, image_dim, image_dim),
                                 dtype=torch.float32, device='cuda')
        # Warm-up so lazy initialization and kernel selection don't skew the timing
        for _ in range(10):
            with torch.no_grad():
                _ = test_model(dummy_input)

        # Synchronize so the host timer only measures completed GPU work
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(loops):
            with torch.no_grad():
                prediction = test_model(dummy_input)
        torch.cuda.synchronize()
        elapsed = time.time() - start_time

        total_images = batch * loops
        print(f"Batch size: {batch}, total_images: {total_images}, total time: {elapsed:.4f} s, throughput: {total_images / elapsed:.2f} images/s")

Any insights into these issues or recommendations on how to optimize would be greatly appreciated! Thank you!

Your profiling is wrong since CUDA operations are executed asynchronously. You would need to synchronize the code before starting and stopping the host timers.
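
A minimal sketch of the event-based alternative, reusing test_model, dummy_input, and loops from the script above (this assumes everything runs on the default CUDA stream):

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
with torch.no_grad():
    for _ in range(loops):
        _ = test_model(dummy_input)
end.record()

# Wait for the recorded work to finish before reading the elapsed time
torch.cuda.synchronize()
elapsed = start.elapsed_time(end) / 1000.0  # elapsed_time() returns milliseconds

Both approaches should report roughly the same numbers once the synchronization is in place.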

Thanks, @ptrblck, for the suggestion! I added the synchronization as advised (and updated the code above), but unfortunately the numbers now look much worse: throughput barely scales with batch size at all. Below are the updated logs:

Logs:

Image resolution: 64 x 64

Batch size: 2, total_images: 10000, total time: 0.5889 s, throughput: 16981.11 images/s
Batch size: 4, total_images: 10000, total time: 0.3752 s, throughput: 26650.04 images/s
Batch size: 8, total_images: 10000, total time: 0.4081 s, throughput: 24506.01 images/s
Batch size: 16, total_images: 10000, total time: 0.4505 s, throughput: 22199.73 images/s
Batch size: 32, total_images: 10016, total time: 0.4338 s, throughput: 23087.40 images/s
Batch size: 64, total_images: 10048, total time: 0.4271 s, throughput: 23527.71 images/s
Batch size: 128, total_images: 10112, total time: 0.4314 s, throughput: 23440.83 images/s
Batch size: 256, total_images: 10240, total time: 0.4348 s, throughput: 23548.83 images/s

Image resolution: 128 x 128

Batch size: 2, total_images: 10000, total time: 1.6119 s, throughput: 6204.00 images/s
Batch size: 4, total_images: 10000, total time: 1.6142 s, throughput: 6194.87 images/s
Batch size: 8, total_images: 10000, total time: 1.5798 s, throughput: 6330.02 images/s
Batch size: 16, total_images: 10000, total time: 1.7125 s, throughput: 5839.52 images/s
Batch size: 32, total_images: 10016, total time: 1.7162 s, throughput: 5836.26 images/s
Batch size: 64, total_images: 10048, total time: 1.7177 s, throughput: 5849.66 images/s
Batch size: 128, total_images: 10112, total time: 1.7199 s, throughput: 5879.31 images/s
Batch size: 256, total_images: 10240, total time: 1.7449 s, throughput: 5868.54 images/s

Image resolution: 256 x 256

Batch size: 2, total_images: 10000, total time: 6.4418 s, throughput: 1552.36 images/s
Batch size: 4, total_images: 10000, total time: 6.4367 s, throughput: 1553.59 images/s
Batch size: 8, total_images: 10000, total time: 6.4346 s, throughput: 1554.10 images/s
Batch size: 16, total_images: 10000, total time: 6.8206 s, throughput: 1466.15 images/s
Batch size: 32, total_images: 10016, total time: 6.8271 s, throughput: 1467.09 images/s
Batch size: 64, total_images: 10048, total time: 6.8346 s, throughput: 1470.16 images/s
Batch size: 128, total_images: 10112, total time: 7.2054 s, throughput: 1403.40 images/s
Batch size: 256, total_images: 10240, total time: 7.3281 s, throughput: 1397.36 images/s

Image resolution: 512 x 512

Batch size: 2, total_images: 10000, total time: 27.9309 s, throughput: 358.03 images/s
Batch size: 4, total_images: 10000, total time: 29.0995 s, throughput: 343.65 images/s
Batch size: 8, total_images: 10000, total time: 28.2753 s, throughput: 353.67 images/s
Batch size: 16, total_images: 10000, total time: 27.5091 s, throughput: 363.52 images/s
Batch size: 32, total_images: 10016, total time: 30.8132 s, throughput: 325.06 images/s
Batch size: 64, total_images: 10048, total time: 30.7247 s, throughput: 327.03 images/s
Batch size: 128, total_images: 10112, total time: 30.8619 s, throughput: 327.65 images/s
Batch size: 256, GPU ran out of memory 

Observations:

  1. Batch scaling is flat, or even regresses, in many cases.
    For example, at resolutions of 128x128 and above, throughput stagnates even as the batch size increases.
  2. Input resolution drastically affects scaling.
    At higher resolutions such as 256x256 and 512x512, the GPU appears to saturate early, and increasing the batch size offers negligible throughput improvement.

Is there an inherent limitation in how PyTorch handles large batch sizes and resolutions, or could this be a hardware bottleneck? Would appreciate any further insights!

The results are not worse; they now reflect the actual GPU execution time rather than just the overhead of dispatching the kernels.

If the scaling is limited, you might be compute-bound and could verify it with e.g. Nsight Compute.
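
If Nsight Compute is awkward to run on Colab, the built-in torch.profiler is a lighter-weight first check that shows per-kernel CUDA time. A minimal sketch (the input shape and iteration count here are just examples; adjust them to the case you care about):

from torch.profiler import profile, ProfilerActivity

x = torch.rand(64, 3, 512, 512, device='cuda')
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(10):
            _ = test_model(x)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

If the convolution kernels dominate the table and their time grows roughly in proportion to the batch size, the GPU is already saturated and flat batch scaling is expected.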

Thanks @ptrblck for the advice. I will check with Nsight Compute and update here.