Segmentation fault when running IPEX bf16 example with torch.autocast

Hi, I am getting a segmentation fault when running the IPEX BF16 example with torch.autocast.
Environment: PyTorch 1.11.0 + intel-extension-for-pytorch 1.11.0
Code to reproduce the error:

import torch
import torchvision
import intel_extension_for_pytorch as ipex

LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
        root=DATA,
        train=True,
        transform=transform,
        download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=128
)

model = torchvision.models.resnet50()
model = model.to(memory_format=torch.channels_last)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum=0.9)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    # Using torch.cpu.amp.autocast() here instead avoids the crash.
    with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
        # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
        data = data.to(memory_format=torch.channels_last)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
    optimizer.step()
    print(batch_idx)

I tried to debug it with gdb and got the following error:

Thread 110 "python" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7f4d43875700 (LWP 38270)]
0x00007f4d07f26e90 in dnnl::impl::cpu::x64::brgemm_inner_product_bwd_weights_t<(dnnl::impl::cpu::x64::cpu_isa_t)79>::compute_diff_weights_and_bias(dnnl::impl::cpu::x64::brgemm_inner_product_bwd_weights_t<(dnnl::impl::cpu::x64::cpu_isa_t)79>::thread_info_t const*) const::{lambda(int, int, int)#2}::operator()(int, int, int) const ()
   from /root/anaconda3/envs/nanoPytorch1.12/lib/python3.7/site-packages/intel_extension_for_pytorch/lib/libintel-ext-pt-cpu.so

The docs say that torch.autocast("cpu", args...) is equivalent to torch.cpu.amp.autocast(args...), but the same code runs without errors when I use torch.cpu.amp.autocast().
So my question is: is there any difference between torch.cpu.amp.autocast(...) and torch.autocast(...)?
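
For anyone who wants to compare the two forms in isolation, here is a minimal sketch that runs the same layer under both context managers and prints the output dtypes. It assumes a plain PyTorch CPU build and deliberately leaves out IPEX and ResNet-50; the small Linear layer and the random input are hypothetical stand-ins, so this checks only that both forms autocast to bfloat16 and does not reproduce the crash.

```python
import torch

# Tiny stand-in model and input (hypothetical; not the ResNet-50 from the report).
layer = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

# Form 1: generic entry point with an explicit device type.
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    out_generic = layer(x)

# Form 2: CPU-specific entry point.
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    out_cpu = layer(x)

# Both outputs should have been produced in bfloat16 by autocast.
print(out_generic.dtype, out_cpu.dtype)
```

If both dtypes match here, any difference in behavior is more likely to come from the kernels selected in the full IPEX training setup than from the context managers themselves, although that is only a guess based on the gdb trace.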

Could you test if this code still fails in the latest PyTorch release and if so, create an issue on GitHub, please?

Thank you for your timely reply!
The segmentation fault no longer occurs, and the code works well with PyTorch 1.12.1+cpu.