RuntimeError when using amp on CPU with CrossEntropyLoss weights

Hi, I get a RuntimeError when using amp on CPU together with weights on the criterion.
It looks like an internal bug to me, because amp with criterion weights works fine on GPU, but it doesn't on CPU.
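
Stripped down, the failure seems to reduce to the case below (my assumption, not confirmed: the criterion's float32 weight tensor clashes with the bfloat16 logits that CPU autocast produces):

import torch
import torch.nn as nn

# Minimal sketch of the failing case, assuming the float32 weight is the culprit
crit = nn.CrossEntropyLoss(weight=torch.rand(5))
x = torch.rand(10, 5)
y = torch.randint(5, size=(10,))
with torch.autocast(device_type='cpu'):
    logits = nn.Linear(5, 5)(x)  # autocast produces bfloat16 logits on CPU
    loss = crit(logits, y)       # RuntimeError: expected scalar type BFloat16 but found Float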

Code:

import torch
import torch.nn as nn

# %%
def test(x32, x16, y, m, crit, device_type):
    print(device_type)

    # 1) Baseline: float32 forward pass and loss, no autocast
    x_pred = m(x32)
    loss = crit(x_pred, y)
    print(1, loss)

    # 2) bfloat16 input straight into the criterion, no autocast
    try:
        loss = crit(x16, y)
        print(2, loss.dtype)
    except Exception as e:
        print(2, e)

    # 3)-5) The same checks under autocast
    with torch.autocast(device_type=device_type):
        x_pred = m(x32)
        try:
            loss = crit(x_pred, y)  # Error on CPU + criterion weights
            print(3, loss)
        except Exception as e:
            print(3, e)
        # Note: if step 3 raised, loss still holds its previous value here
        print(4, x32.dtype, x_pred.dtype, loss.dtype, y)

        try:
            loss = crit(x16, y)  # Also an error on CPU + criterion weights
            print(5, loss.dtype)
        except Exception as e:
            print(5, e)
    print(6, loss.dtype)

# %% markdown
## GPU + Without weights in criterion
# %%
crit = nn.CrossEntropyLoss().cuda()

m = nn.Linear(5, 5)
m.cuda()

# %%
x32 = torch.rand(10, 5).float().cuda()
x16 = torch.rand(10, 5, dtype=torch.bfloat16).cuda()
y = torch.randint(5, size=(10,)).cuda()

# %%
test(x32, x16, y, m, crit, 'cuda')

# %% markdown
## GPU + With weights in criterion
# %%
weight = torch.rand(5)
crit = nn.CrossEntropyLoss(weight=weight).cuda()

m = nn.Linear(5, 5)
m.cuda()

# %%
x32 = torch.rand(10, 5).float().cuda()
x16 = torch.rand(10, 5, dtype=torch.bfloat16).cuda()
y = torch.randint(5, size=(10,)).cuda()

# %%
test(x32, x16, y, m, crit, 'cuda')

# %% markdown
## CPU + Without weights in criterion
# %%
crit = nn.CrossEntropyLoss()

m = nn.Linear(5, 5)

# %%
x32 = torch.rand(10, 5).float()
x16 = torch.rand(10, 5, dtype=torch.bfloat16)
y = torch.randint(5, size=(10,))

# %%
test(x32, x16, y, m, crit, 'cpu')

# %% markdown
## CPU + With weights in criterion
# %%
weight = torch.rand(5)
crit = nn.CrossEntropyLoss(weight=weight)

m = nn.Linear(5, 5)

# %%
x32 = torch.rand(10, 5).float()
x16 = torch.rand(10, 5, dtype=torch.bfloat16)
y = torch.randint(5, size=(10,))

# %%
test(x32, x16, y, m, crit, 'cpu')

Output:

cuda
1 tensor(1.8212, device='cuda:0', grad_fn=<NllLossBackward0>)
2 torch.bfloat16
3 tensor(1.8212, device='cuda:0', grad_fn=<NllLossBackward0>)
4 torch.float32 torch.float16 torch.float32 tensor([0, 1, 4, 0, 2, 4, 0, 1, 3, 3], device='cuda:0')
5 torch.float32
6 torch.float32
cuda
1 tensor(1.2175, device='cuda:0', grad_fn=<NllLossBackward0>)
2 expected scalar type BFloat16 but found Float
3 tensor(1.2176, device='cuda:0', grad_fn=<NllLossBackward0>)
4 torch.float32 torch.float16 torch.float32 tensor([1, 4, 3, 1, 2, 1, 1, 0, 0, 4], device='cuda:0')
5 torch.float32
6 torch.float32
cpu
1 tensor(1.5414, grad_fn=<NllLossBackward0>)
2 torch.bfloat16
3 tensor(1.5469, dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)
4 torch.float32 torch.bfloat16 torch.bfloat16 tensor([3, 2, 2, 2, 3, 0, 2, 0, 2, 2])
5 torch.bfloat16
6 torch.bfloat16
cpu
1 tensor(1.7708, grad_fn=<NllLossBackward0>)
2 expected scalar type BFloat16 but found Float
3 expected scalar type BFloat16 but found Float
4 torch.float32 torch.bfloat16 torch.float32 tensor([1, 4, 2, 0, 4, 4, 3, 2, 2, 4])
5 expected scalar type BFloat16 but found Float
6 torch.float32
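
In case it helps someone, the following seems to work around the error (a sketch, assuming the mismatch is only between the bfloat16 logits and the float32 weight: run the weighted loss in a disabled-autocast region with float32 logits):

with torch.autocast(device_type='cpu'):
    x_pred = m(x32)
    # Cast back to float32 and disable autocast for the criterion
    with torch.autocast(device_type='cpu', enabled=False):
        loss = crit(x_pred.float(), y)

But as far as I can tell this shouldn't be necessary, since the same code runs fine on GPU.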