Hi everyone,
I’m running into issues with batch scaling and input resolution while benchmarking a simple two-layer convolutional network in PyTorch, and I’d like some help understanding why I’m seeing the behavior below. Here are the details:
Issue 1: Non-Linear Batch Scaling
Throughput doesn’t scale linearly as the batch size increases. For example, in the following test at 64×64 resolution, throughput is relatively flat for batch sizes 4 through 32, then jumps noticeably at 64 and 128, and especially at 256.
Here are the logs for 64×64 resolution:
Batch size: 2, total_images: 10000, total time: 0.6013 s, throughput: 16629.97 images/s
Batch size: 4, total_images: 10000, total time: 0.3541 s, throughput: 28240.33 images/s
Batch size: 8, total_images: 10000, total time: 0.3828 s, throughput: 26123.36 images/s
Batch size: 16, total_images: 10000, total time: 0.3639 s, throughput: 27477.99 images/s
Batch size: 32, total_images: 10016, total time: 0.3550 s, throughput: 28213.21 images/s
Batch size: 64, total_images: 10048, total time: 0.2698 s, throughput: 37248.43 images/s
Batch size: 128, total_images: 10112, total time: 0.2140 s, throughput: 47253.55 images/s
Batch size: 256, total_images: 10240, total time: 0.1388 s, throughput: 73779.19 images/s
Is it expected for throughput to be relatively flat for smaller batch sizes and then increase significantly at larger batch sizes?
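To make Issue 1 easier to reason about, here is a small sketch that just does arithmetic on the 64×64 numbers logged above (nothing new is measured; the throughput values are copied from the log) and converts them into approximate per-image and per-forward-pass times:

    # Convert the logged 64x64 throughput (images/s) into approximate latencies.
    logs_64 = {
        2: 16629.97,
        4: 28240.33,
        8: 26123.36,
        16: 27477.99,
        32: 28213.21,
        64: 37248.43,
        128: 47253.55,
        256: 73779.19,
    }

    for batch, throughput in logs_64.items():
        per_image_us = 1e6 / throughput          # time per image in microseconds
        per_batch_ms = batch / throughput * 1e3  # time per forward pass in milliseconds
        print(f"batch {batch:>3}: ~{per_image_us:5.1f} us/image, ~{per_batch_ms:6.3f} ms/forward pass")

By this arithmetic, per-image time stays around 35 µs from batch 4 through 32 and only starts dropping at batch 64 and above, which is exactly the flat region I’m asking about.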
Issue 2: Throughput Depends on Input Resolution
When I change the input resolution, the scaling behavior changes significantly. Larger image resolutions result in lower throughput and less efficient scaling. Here are the logs for various resolutions:
Image Resolution: 64×64
Batch size: 2, total_images: 10000, total time: 0.6013 s, throughput: 16629.97 images/s
Batch size: 4, total_images: 10000, total time: 0.3541 s, throughput: 28240.33 images/s
Batch size: 8, total_images: 10000, total time: 0.3828 s, throughput: 26123.36 images/s
Batch size: 16, total_images: 10000, total time: 0.3639 s, throughput: 27477.99 images/s
Batch size: 32, total_images: 10016, total time: 0.3550 s, throughput: 28213.21 images/s
Batch size: 64, total_images: 10048, total time: 0.2698 s, throughput: 37248.43 images/s
Batch size: 128, total_images: 10112, total time: 0.2140 s, throughput: 47253.55 images/s
Batch size: 256, total_images: 10240, total time: 0.1388 s, throughput: 73779.19 images/s
Image Resolution: 128×128
Batch size: 2, total_images: 10000, total time: 2.5050 s, throughput: 3991.97 images/s
Batch size: 4, total_images: 10000, total time: 1.5800 s, throughput: 6328.93 images/s
Batch size: 8, total_images: 10000, total time: 1.5112 s, throughput: 6617.18 images/s
Batch size: 16, total_images: 10000, total time: 1.5851 s, throughput: 6308.77 images/s
Batch size: 32, total_images: 10016, total time: 1.3791 s, throughput: 7262.69 images/s
Batch size: 64, total_images: 10048, total time: 1.0614 s, throughput: 9466.93 images/s
Batch size: 128, total_images: 10112, total time: 0.8542 s, throughput: 11838.45 images/s
Batch size: 256, total_images: 10240, total time: 0.5507 s, throughput: 18595.67 images/s
Image Resolution: 256×256
Batch size: 2, total_images: 10000, total time: 10.0976 s, throughput: 990.33 images/s
Batch size: 4, total_images: 10000, total time: 6.4409 s, throughput: 1552.57 images/s
Batch size: 8, total_images: 10000, total time: 6.2662 s, throughput: 1595.86 images/s
Batch size: 16, total_images: 10000, total time: 6.3972 s, throughput: 1563.18 images/s
Batch size: 32, total_images: 10016, total time: 5.5516 s, throughput: 1804.16 images/s
Batch size: 64, total_images: 10048, total time: 4.2874 s, throughput: 2343.59 images/s
Batch size: 128, total_images: 10112, total time: 7.4968 s, throughput: 1348.85 images/s
Batch size: 256, total_images: 10240, total time: 7.5911 s, throughput: 1348.94 images/s
Image Resolution: 512×512
Batch size: 2, total_images: 10000, total time: 28.2943 s, throughput: 353.43 images/s
Batch size: 4, total_images: 10000, total time: 26.5243 s, throughput: 377.01 images/s
Batch size: 8, total_images: 10000, total time: 25.8578 s, throughput: 386.73 images/s
Batch size: 16, total_images: 10000, total time: 25.6787 s, throughput: 389.43 images/s
Batch size: 32, total_images: 10016, total time: 29.4280 s, throughput: 340.36 images/s
Batch size: 64, total_images: 10048, total time: 29.6016 s, throughput: 339.44 images/s
Batch size: 128, total_images: 10112, total time: 29.9043 s, throughput: 338.15 images/s
Batch size: 256, GPU ran out of memory
As shown above, larger resolutions not only lower absolute throughput but also scale less efficiently with batch size: at 256×256 the throughput actually drops again at batch sizes 128 and 256, and at 512×512 the curve is essentially flat before batch 256 runs out of memory. My understanding was that throughput should scale by a roughly consistent factor regardless of input resolution. Is this behavior expected?
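For scale, here is a rough back-of-envelope sketch of how much work and intermediate memory a single image costs this two-conv model at each resolution. This is my own estimate (counting 2 FLOPs per multiply-add for the 3×3 convolutions and assuming float32 activations), not something measured on the T4:

    # Rough per-image cost of the two 3x3 convs (3->64 and 64->3) with 'same' padding.
    def conv_flops(h, w, c_in, c_out, k=3):
        return 2 * h * w * c_in * c_out * k * k  # 2 FLOPs per multiply-add

    for dim in (64, 128, 256, 512):
        flops = conv_flops(dim, dim, 3, 64) + conv_flops(dim, dim, 64, 3)
        act_bytes = dim * dim * 64 * 4  # 64-channel intermediate feature map, float32
        print(f"{dim}x{dim}: ~{flops / 1e6:.1f} MFLOPs/image, "
              f"~{act_bytes / 1e6:.1f} MB intermediate activation/image")

By this estimate the per-image cost grows roughly 4× with every doubling of resolution, and at 512×512 the 64-channel intermediate activation is about 67 MB per image, so a batch of 256 would need roughly 17 GB for that tensor alone, which seems consistent with the out-of-memory error on the 16 GB T4.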
Code to Reproduce
Here’s the code I used for testing. This was run on a Google Colab instance with a T4 GPU:
import torch.nn as nn
import torch
import time
import math


class SimpleConvNet(nn.Module):
    def __init__(self, in_filters=3, out_filters=3):
        super().__init__()
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_filters, 64, kernel_size=3, padding=1),
            nn.Conv2d(64, out_filters, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = self.conv2(x)
        return x


test_model = SimpleConvNet().to(device='cuda')

batches = [2, 4, 8, 16, 32, 64, 128, 256]
image_dims = [int(512/8), int(512/4), int(512/2), int(512/1)]
examples = 10000

for image_dim in image_dims:
    print(f"Image resolution: {image_dim} x {image_dim}")
    for batch in batches:
        loops = math.ceil(examples / batch)
        dummy_input = torch.rand((batch, 3, image_dim, image_dim),
                                 dtype=torch.float32, device='cuda')

        # Warm-up
        for _ in range(10):
            with torch.no_grad():
                _ = test_model(dummy_input)
        torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(loops):
            with torch.no_grad():
                prediction = test_model(dummy_input)
        torch.cuda.synchronize()
        elapsed = time.time() - start_time

        total_images = batch * loops
        print(f"Batch size: {batch}, total_images: {total_images}, "
              f"total time: {elapsed:.4f} s, throughput: {total_images / elapsed:.2f} images/s")
Any insights into these issues or recommendations on how to optimize would be greatly appreciated! Thank you!