Hello, I have an A100 GPU, which is supposed to be more performant in FP16 than in FP32. If I test simple operations I do get better results for FP16 and BF16; however, for Conv2d this is not the case. Example:
import torch
import time
import os
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
device = "cuda:0"
bsz = 5012
shape = (bsz, 5, 28, 28)
# Create tensors
sfp32 = torch.arange(bsz*5*28*28).reshape(shape).type(torch.float32).to(device)
sfp16 = sfp32.half()
sbf16 = sfp32.to(torch.bfloat16)
# Simple FP32 operation (first 100 iterations are warmup)
for idx in range(10000):
    if idx == 100:
        start_time = time.time()
    sfp32_sum = torch.sum(sfp32)
torch.cuda.synchronize()
print("Simple FP32 operation:", (time.time() - start_time) / 9000)

# Simple FP16 operation
for idx in range(10000):
    if idx == 100:
        start_time = time.time()
    sfp16_sum = torch.sum(sfp16)
torch.cuda.synchronize()
print("Simple FP16 operation:", (time.time() - start_time) / 9000)

# Simple BFLOAT16 operation
for idx in range(10000):
    if idx == 100:
        start_time = time.time()
    sbf16_sum = torch.sum(sbf16)
torch.cuda.synchronize()
print("Simple BFLOAT16 operation:", (time.time() - start_time) / 9000)
# Performance of convolution layers
conv_fp32 = torch.nn.Conv2d(5, 640, kernel_size=3, padding=1).to(device)
conv_fp16 = torch.nn.Conv2d(5, 640, kernel_size=3, padding=1).half().to(device)
conv_bf16 = torch.nn.Conv2d(5, 640, kernel_size=3, padding=1).to(torch.bfloat16).to(device)
# Convolution with FP32 (first 100 iterations are warmup)
for idx in range(1000):
    if idx == 100:
        start_time = time.time()
    conv_out_fp32 = conv_fp32(sfp32)
torch.cuda.synchronize()
print("Conv FP32:", (time.time() - start_time) / 900)

# Convolution with FP16
for idx in range(1000):
    if idx == 100:
        start_time = time.time()
    conv_out_fp16 = conv_fp16(sfp16)
torch.cuda.synchronize()
print("Conv FP16:", (time.time() - start_time) / 900)

# Convolution with BFLOAT16
for idx in range(1000):
    if idx == 100:
        start_time = time.time()
    conv_out_bf16 = conv_bf16(sbf16)
torch.cuda.synchronize()
print("Conv BFLOAT16:", (time.time() - start_time) / 900)
Simple FP32 operation: 5.93457751803928e-05
Simple FP16 operation: 3.520531124538846e-05
Simple BFLOAT16 operation: 3.5241656833224824e-05
Conv FP32: 0.025614858998192682
Conv FP16: 0.05739478005303277
Conv BFLOAT16: 0.058520449267493356
Any idea why?
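In case the timing method itself is the problem: the same conv measurement can also be written with CUDA events (a minimal sketch reusing conv_fp16/sfp16 from above; this is not the code that produced the numbers in this post):

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
for _ in range(100):  # warmup
    conv_fp16(sfp16)
torch.cuda.synchronize()
start_evt.record()
for _ in range(900):
    conv_fp16(sfp16)
end_evt.record()
torch.cuda.synchronize()
# elapsed_time returns milliseconds between the two recorded events
print("Conv FP16 (CUDA events, ms/iter):", start_evt.elapsed_time(end_evt) / 900)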
Interestingly, looking at the nsys report, I see different operations being launched for float32 than for the other two. Not sure if I should look for something specific in the report.
Float32: [nsys kernel trace screenshot]
Float16: [nsys kernel trace screenshot]
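If it is easier than digging through nsys, the dispatched kernels can also be compared with torch.profiler (a sketch, reusing the conv layers and inputs defined above):

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        conv_fp32(sfp32)
        conv_fp16(sfp16)
    torch.cuda.synchronize()
# The table lists the CUDA kernels each conv dispatched, sorted by total GPU time
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))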
For context (see the snippet after the list for how to query these from Python):
- Driver version: 550.54.15
- CUDA version: 12.4
- GPU: NVIDIA A100 80GB PCIe
- Torch version: 2.4.0a0+git705346b (also tried 2.3.0+cu121)
- cuDNN version: 90101
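A minimal sketch of how these can be read programmatically, for anyone reproducing this:

import torch

print("torch:", torch.__version__)
print("CUDA (torch build):", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
print("GPU:", torch.cuda.get_device_name(0))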
Thanks!
Update: with linear layers I get the expected results; bfloat16 and float16 are much more performant than float32. So I think the issue is specific to Conv2d.
bsz = 5012
# Performance of linear layers
linear_fp32 = torch.nn.Linear(5 * 28 * 28, 2048).to(device)
linear_fp16 = torch.nn.Linear(5 * 28 * 28, 2048).half().to(device)
linear_bf16 = torch.nn.Linear(5 * 28 * 28, 2048).to(torch.bfloat16).to(device)
# Flatten the input tensors for linear layers
flat_sfp32 = sfp32.view(bsz, -1)
flat_sfp16 = sfp16.view(bsz, -1)
flat_sbf16 = sbf16.view(bsz, -1)
# Linear with FP32 (first 100 iterations are warmup)
for idx in range(1000):
    if idx == 100:
        start_time = time.time()
    linear_out_fp32 = linear_fp32(flat_sfp32)
torch.cuda.synchronize()
print("Linear FP32:", (time.time() - start_time) / 900)

# Linear with FP16
for idx in range(1000):
    if idx == 100:
        start_time = time.time()
    linear_out_fp16 = linear_fp16(flat_sfp16)
torch.cuda.synchronize()
print("Linear FP16:", (time.time() - start_time) / 900)

# Linear with BFLOAT16
for idx in range(1000):
    if idx == 100:
        start_time = time.time()
    linear_out_bf16 = linear_bf16(flat_sbf16)
torch.cuda.synchronize()
print("Linear BFLOAT16:", (time.time() - start_time) / 900)
---
Linear FP32: 0.004868701299031575
Linear FP16: 0.000530137750837538
Linear BFLOAT16: 0.0003604027960035536
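One more variant I would want to compare, but have not timed here: keeping the conv weights in FP32 and letting autocast pick the half-precision kernels, instead of calling .half() on the module (a sketch only, reusing conv_fp32 and sfp32 from above):

# Same benchmarking pattern as above, but under autocast instead of .half()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    for idx in range(1000):
        if idx == 100:
            torch.cuda.synchronize()
            start_time = time.time()
        conv_out_amp = conv_fp32(sfp32)
torch.cuda.synchronize()
print("Conv FP16 via autocast:", (time.time() - start_time) / 900)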