I am trying to figure out whether I should use bfloat16 or float32 for training and inference on CPU.
This is the code I use to test:
import torch
import time
# Benchmark configuration: shape of the square test tensor and how many
# timed repetitions are averaged for each dtype.
tensor_size = (1_000, 1_000)
num_iterations = 100
def perform_operations(data):
    """Time five basic tensor operations on ``data``, one call each.

    Args:
        data: A tensor with at most two dimensions, because ``Tensor.t()``
            only accepts 0-D, 1-D or 2-D inputs.

    Returns:
        A 5-tuple of elapsed seconds, in this order:
        (square, sum, mean, transpose, matmul).
    """

    def _elapsed(fn, *args):
        # Run fn(*args) once and return how long it took in seconds.
        # time.perf_counter() is monotonic and high-resolution, which is what
        # short-interval benchmarks need; time.time() is wall-clock time with
        # coarser resolution that can jump if the system clock is adjusted.
        start = time.perf_counter()
        fn(*args)
        return time.perf_counter() - start

    squared_time = _elapsed(torch.square, data)
    sum_time = _elapsed(torch.sum, data)
    mean_time = _elapsed(torch.mean, data)
    # Tensor.t() returns a view (no data is copied), so this timing reflects
    # little more than Python/dispatcher overhead.
    transpose_time = _elapsed(torch.t, data)
    transpose = data.t()
    matmul_time = _elapsed(torch.matmul, data, transpose)
    return squared_time, sum_time, mean_time, transpose_time, matmul_time
# Labels for the timings returned by perform_operations, in the same order.
OP_NAMES = ("Squared", "Sum", "Mean", "Transpose", "Matmul")

# Warm-up pass whose timings are discarded: the first call of an op can
# include one-time costs (e.g. thread-pool start-up, allocator growth) that
# would otherwise be folded into the averages below.
_warmup = torch.randn(tensor_size, dtype=torch.float32)
perform_operations(_warmup.to(torch.bfloat16))
perform_operations(_warmup)
del _warmup

# Accumulate per-operation totals for each dtype. Both dtypes are timed on the
# same random values; the bfloat16 conversion itself is not timed.
bfloat16_times = [0.0] * len(OP_NAMES)
float32_times = [0.0] * len(OP_NAMES)
for i in range(num_iterations):
    data = torch.randn(tensor_size, dtype=torch.float32)
    data_bfloat16 = data.to(torch.bfloat16)
    bfloat16_op_times = perform_operations(data_bfloat16)
    for j, elapsed in enumerate(bfloat16_op_times):
        bfloat16_times[j] += elapsed
    float32_op_times = perform_operations(data)
    for j, elapsed in enumerate(float32_op_times):
        float32_times[j] += elapsed

# Convert totals to per-iteration averages. The loop variable is deliberately
# not named `time`, so it cannot be confused with the imported time module.
bfloat16_times = [total / num_iterations for total in bfloat16_times]
float32_times = [total / num_iterations for total in float32_times]

# NOTE(review): this benchmark does not control for hardware bfloat16
# support. CPUs without native bf16 matrix instructions (e.g. AVX512-BF16/AMX,
# which a 9th-gen Core i7 lacks) may run bfloat16 matmul through a much slower
# path than float32 GEMM. That is a likely explanation for the matmul gap, but
# confirm with torch.__config__.show() / a profiler.
print("Average time for bfloat16:")
for name, seconds in zip(OP_NAMES, bfloat16_times):
    print(f"{name}: {seconds:.6f} seconds")
print()
print("Average time for float32:")
for name, seconds in zip(OP_NAMES, float32_times):
    print(f"{name}: {seconds:.6f} seconds")
And this is the result I got:
Average time for bfloat16:
Squared: 0.000355 seconds
Sum: 0.000183 seconds
Mean: 0.000237 seconds
Transpose: 0.000027 seconds
Matmul: 0.299079 seconds
Average time for float32:
Squared: 0.000585 seconds
Sum: 0.000239 seconds
Mean: 0.000123 seconds
Transpose: 0.000010 seconds
Matmul: 0.005496 seconds
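I wonder why bfloat16 is faster than float32 for torch.square() and torch.sum(), yet roughly 50 times slower for torch.matmul(). (bfloat16 is also slightly slower for torch.mean() and the transpose, although those absolute times are tiny.)
The code runs on an Intel 9th-gen Core i7 CPU.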