Autocast on CPU dramatically slow

Hi,

On a toy regression model with PyTorch 2.1.2 on CPU, torch.autocast is really slow.
Without the with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=True) context manager in the code below, I get:

Epoch 0: Loss: 5.163879e-01 in 2.6s.
Epoch 1: Loss: 5.981598e+01 in 2.5s.
Epoch 2: Loss: 6.057042e-01 in 2.5s.
Epoch 3: Loss: 3.923391e+00 in 2.5s.
Epoch 4: Loss: 1.904698e+00 in 2.5s.
training done in 12.7s.

With torch.autocast, I get:

Epoch 0: Loss: 5.164276e-01 in 230.8s.
Epoch 1: Loss: 5.980967e+01 in 231.1s.
Epoch 2: Loss: 6.053763e-01 in 233.0s.
Epoch 3: Loss: 3.924905e+00 in 232.8s.
Epoch 4: Loss: 1.904750e+00 in 232.0s.
training done in 1159.7s.

So the run with autocast is roughly 90x slower, close to two orders of magnitude.
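
To narrow down where the time goes, one forward/backward iteration could be run under the PyTorch profiler. This is only a minimal sketch (I have not included its output here); it reuses model, criterion, x_train and y_train from the script below:

from torch.profiler import profile, ProfilerActivity

# profile a single autocast iteration to see which CPU ops dominate
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=True):
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
    loss.backward()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))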

Am I using autocast correctly in the code below? Thanks.

import torch
import torch.nn as nn
import numpy as np
from time import time
torch.set_num_threads(1)

input_size = 1
output_size = 1
hidden_size = 512
num_data = 100000

# seeds
seed = 1234
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True)
np.random.seed(seed)

# hyper-parameters
num_epochs = 5
learning_rate = 0.01

# toy dataset
x_train = np.random.rand(num_data, input_size)
y_train = np.cos(2 * np.pi * x_train) + 0.1 * np.random.randn(num_data, input_size)

# regression model
model = nn.Sequential(nn.Linear(input_size, hidden_size),
                      nn.GELU(),
                      nn.Linear(hidden_size, hidden_size),
                      nn.GELU(),
                      nn.Linear(hidden_size, output_size))

# loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-2)

# train the model
x_train = torch.from_numpy(x_train.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))

global_start_time = time()
for epoch in range(num_epochs):
    start_time = time()
    
    # forward pass
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=True):
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
    
    # backward and optimize
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    print(f'Epoch {epoch}: Loss: {loss.item():.6e} in {time()-start_time:.1f}s.')
print(f'training done in {time()-global_start_time:.1f}s.')
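
For comparison (not something I have benchmarked here), the per-op dtype casting that autocast performs can be taken out of the picture by running the model natively in bfloat16. A minimal sketch, assuming the model and tensors from the script above; copy.deepcopy keeps the original float32 model intact:

import copy

# hypothetical comparison: cast the model and data to bfloat16 once,
# instead of letting autocast convert dtypes inside every op; if this
# is fast, the overhead likely sits in autocast's casting machinery
# rather than in the bfloat16 kernels themselves
model_bf16 = copy.deepcopy(model).to(torch.bfloat16)
x_bf16 = x_train.to(torch.bfloat16)
y_bf16 = y_train.to(torch.bfloat16)

start_time = time()
outputs = model_bf16(x_bf16)
loss = criterion(outputs, y_bf16)
print(f'native bf16 forward in {time()-start_time:.1f}s, loss {loss.item():.6e}')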

PyTorch built with:

  • GCC 9.3
  • C++ Version: 201703
  • Intel(R) oneAPI Math Kernel Library Version 2022.1-Product Build 20220311 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • LAPACK is enabled (usually provided by MKL)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.2, USE_CUDA=0, USE_CUDNN=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,
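
One detail that may be relevant: the build reports CPU capability usage AVX2, i.e. no AVX-512, so bfloat16 arithmetic is presumably emulated rather than run on native bf16 instructions. A quick runtime check (assuming PyTorch >= 2.0, where this helper exists):

# prints the SIMD level PyTorch dispatches to, e.g. "AVX2" or "AVX512"
print(torch.backends.cpu.get_cpu_capability())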

I don’t see anything obviously wrong in your code, so could you create an issue on GitHub, please?

Thanks for your reply.
I will create an issue.

#118499
