Hi, I am getting a segmentation fault when running the IPEX BF16 example with torch.autocast.
Environment: pytorch 1.11.0 + intel-extension-for-pytorch 1.11.0
Code to reproduce the error:
import torch
import torchvision
import intel_extension_for_pytorch as ipex
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128
)
model = torchvision.models.resnet50()
model = model.to(memory_format=torch.channels_last)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum=0.9)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    # with torch.cpu.amp.autocast():
        # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
        data = data.to(memory_format=torch.channels_last)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
    optimizer.step()
    print(batch_idx)
I tried to debug it with gdb.
Error:
--Type <RET> for more, q to quit, c to continue without paging--
Thread 110 "python" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7f4d43875700 (LWP 38270)]
0x00007f4d07f26e90 in dnnl::impl::cpu::x64::brgemm_inner_product_bwd_weights_t<(dnnl::impl::cpu::x64::cpu_isa_t)79>::compute_diff_weights_and_bias(dnnl::impl::cpu::x64::brgemm_inner_product_bwd_weights_t<(dnnl::impl::cpu::x64::cpu_isa_t)79>::thread_info_t const*) const::{lambda(int, int, int)#2}::operator()(int, int, int) const ()
from /root/anaconda3/envs/nanoPytorch1.12/lib/python3.7/site-packages/intel_extension_for_pytorch/lib/libintel-ext-pt-cpu.so
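In addition to gdb, the standard-library faulthandler module might help map the crash back to a Python line. This is only a sketch of how it could be enabled at the top of the reproducer; it has not been run against this particular crash, and it only reports Python-level frames, not the native oneDNN frames shown above.

```python
import faulthandler

# Ask the interpreter to dump the Python traceback of every thread
# to stderr if the process receives SIGSEGV (or another fatal signal).
faulthandler.enable(all_threads=True)

# The reproducer (imports, model setup, training loop) would follow here.
```

The same effect should be available without editing the script by running it with python -X faulthandler.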
The docs say that torch.autocast("cpu", args...) is equivalent to torch.cpu.amp.autocast(args...), but the same script runs fine with torch.cpu.amp.autocast().
So my question is: is there any difference between torch.cpu.amp.autocast(..) and torch.autocast(..)?
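To compare the two context managers in isolation, a minimal check without IPEX could look like the sketch below. It assumes a plain torch.nn.Linear layer is enough to show which dtype each context selects; the helper name output_dtype is made up for illustration. Note that the reproducer passes dtype=torch.bfloat16 explicitly to torch.autocast, while the commented-out torch.cpu.amp.autocast() call relies on its default, so the sketch passes the dtype in both cases.

```python
import torch


def output_dtype(ctx):
    """Run a small Linear forward pass under the given autocast context
    and return the dtype of the result (hypothetical helper)."""
    layer = torch.nn.Linear(4, 2)
    x = torch.randn(3, 4)
    with ctx:
        return layer(x).dtype


# Generic entry point with an explicit CPU device type.
dtype_generic = output_dtype(
    torch.autocast(device_type="cpu", dtype=torch.bfloat16)
)

# CPU-specific entry point; newer PyTorch releases may emit a
# deprecation warning for this form.
dtype_cpu = output_dtype(torch.cpu.amp.autocast(dtype=torch.bfloat16))

print("torch.autocast:", dtype_generic)
print("torch.cpu.amp.autocast:", dtype_cpu)
```

If both lines print the same dtype, the difference seen in the reproducer is more likely to come from how IPEX interacts with the chosen context than from autocast itself, but that is a guess and not something this check can confirm.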