Why is bfloat16 matmul significantly slower than float32?

I am trying to figure out whether I should use bfloat16 or float32 for CPU training and inference.

This is the code I use to test:

import torch
import time

# Shape passed to torch.randn for the benchmark input (a 1000 x 1000 matrix).
tensor_size = (1000, 1000)
# Number of timed passes per dtype; the reported figures are averaged over these.
num_iterations = 100

def _timed(op, *args):
    """Call ``op(*args)`` and return ``(result, elapsed_seconds)``."""
    # perf_counter is the clock intended for measuring short intervals; unlike
    # time.time() it is not affected by system clock adjustments.
    start = time.perf_counter()
    result = op(*args)
    return result, time.perf_counter() - start


def perform_operations(data):
    """Time five tensor operations on ``data`` and return their durations.

    The operations are run once each, in this order: ``torch.square``,
    ``torch.sum``, ``torch.mean``, ``data.t()`` and ``torch.matmul`` of
    ``data`` with its transpose.

    Args:
        data: Tensor to benchmark. It is expected to be 2-D, since it is
            transposed with ``.t()`` and multiplied by that transpose.

    Returns:
        A 5-tuple of floats, in seconds:
        ``(squared_time, sum_time, mean_time, transpose_time, matmul_time)``.
    """
    _, squared_time = _timed(torch.square, data)
    _, sum_time = _timed(torch.sum, data)
    _, mean_time = _timed(torch.mean, data)
    # The transpose result is kept because the matmul below consumes it.
    transpose, transpose_time = _timed(data.t)
    _, matmul_time = _timed(torch.matmul, data, transpose)
    return squared_time, sum_time, mean_time, transpose_time, matmul_time

# Labels for the five timings, in the order perform_operations() returns them.
op_labels = ("Squared", "Sum", "Mean", "Transpose", "Matmul")

# Untimed warm-up pass: bfloat16 is always measured first in the loop below, so
# without this any one-time first-call cost would be charged to bfloat16 only.
warmup = torch.randn(tensor_size, dtype=torch.float32)
perform_operations(warmup.to(torch.bfloat16))
perform_operations(warmup)

# Running totals per operation; divided by num_iterations after the loop.
bfloat16_times = [0.0] * len(op_labels)
float32_times = [0.0] * len(op_labels)

for _ in range(num_iterations):
    # Both dtypes are timed on the same random values each iteration.
    data = torch.randn(tensor_size, dtype=torch.float32)
    data_bfloat16 = data.to(torch.bfloat16)

    bfloat16_op_times = perform_operations(data_bfloat16)
    bfloat16_times = [
        total + elapsed
        for total, elapsed in zip(bfloat16_times, bfloat16_op_times)
    ]

    float32_op_times = perform_operations(data)
    float32_times = [
        total + elapsed
        for total, elapsed in zip(float32_times, float32_op_times)
    ]

# Convert totals to per-iteration averages. The loop variable is deliberately
# not called `time`, which would shadow the imported `time` module.
bfloat16_times = [total / num_iterations for total in bfloat16_times]
float32_times = [total / num_iterations for total in float32_times]

for dtype_name, averages in (
    ("bfloat16", bfloat16_times),
    ("float32", float32_times),
):
    print("Average time for {}:".format(dtype_name))
    for label, seconds in zip(op_labels, averages):
        print("{}: {:.6f} seconds".format(label, seconds))

And I got the result:

Average time for bfloat16:
Squared: 0.000355 seconds
Sum: 0.000183 seconds
Mean: 0.000237 seconds
Transpose: 0.000027 seconds
Matmul: 0.299079 seconds

Average time for float32:
Squared: 0.000585 seconds
Sum: 0.000239 seconds
Mean: 0.000123 seconds
Transpose: 0.000010 seconds
Matmul: 0.005496 seconds

I wonder why bfloat16 is faster than float32 in all other operations except torch.matmul().
The code runs on an Intel 9th-gen Core i7 CPU.