RuntimeError when using AMP on CPU with CrossEntropyLoss weights

Hi, I get an error when using AMP (autocast) on CPU with class weights passed to the criterion.
It looks like an internal bug to me, because AMP with criterion weights works fine on GPU but fails on CPU.

Code:

import torch
import torch.nn as nn

# %%
def _try_loss(step, crit, inp, target):
    """Evaluate ``crit(inp, target)``, reporting a failure instead of raising.

    Args:
        step: Diagnostic step number printed in front of the error message.
        crit: Loss criterion to call.
        inp: Prediction tensor passed as the criterion's input.
        target: Target tensor passed as the criterion's target.

    Returns:
        The loss tensor, or ``None`` if the criterion raised. In that case
        ``step`` and the exception message have already been printed.
    """
    try:
        return crit(inp, target)
    except Exception as e:  # diagnostic harness: report any failure and continue
        print(step, e)
        return None


def test(x32, x16, y, m, crit, device_type):
    """Print how ``crit`` behaves with and without ``torch.autocast``.

    Numbered lines printed:
      1. loss on the model output for ``x32`` (no autocast)
      2. dtype of the loss on ``x16`` (no autocast)
      3. loss on the model output for ``x32`` under autocast
      4. dtypes of ``x32``, the autocast model output and the step-3 loss
         (``None`` if step 3 failed), followed by the targets ``y``
      5. dtype of the loss on ``x16`` under autocast
      6. dtype of the most recent loss that was computed successfully
    Steps 2, 3 and 5 print the exception message instead when ``crit`` raises.

    Args:
        x32: Input batch for ``m`` (float32 in this script's calls).
        x16: Input passed straight to ``crit`` (bfloat16 in this script's calls).
        y: Class-index targets.
        m: Model producing predictions from ``x32``.
        crit: Loss criterion under test.
        device_type: Device type string forwarded to ``torch.autocast``.
    """
    print(device_type)

    # Standard computation; intentionally unguarded because it is the baseline.
    x_pred = m(x32)
    loss = crit(x_pred, y)
    print(1, loss)

    loss16 = _try_loss(2, crit, x16, y)
    if loss16 is not None:
        print(2, loss16.dtype)
        loss = loss16

    # amp
    with torch.autocast(device_type=device_type):
        x_pred = m(x32)
        amp_loss = _try_loss(3, crit, x_pred, y)  # errors on cpu + criterion weights
        if amp_loss is not None:
            print(3, amp_loss)
            loss = amp_loss
        # Report the step-3 result itself rather than the shared ``loss``.
        # If step 3 failed, ``loss`` still holds an earlier non-autocast
        # result, and its dtype would be misleading here.
        amp_dtype = None if amp_loss is None else amp_loss.dtype
        print(4, x32.dtype, x_pred.dtype, amp_dtype, y)

        amp_loss16 = _try_loss(5, crit, x16, y)  # also errors on cpu + criterion weights
        if amp_loss16 is not None:
            print(5, amp_loss16.dtype)
            loss = amp_loss16
    print(6, loss.dtype)

# %% markdown
## GPU + Without weights in criterion
# %%
# Guarded so the script still reaches the CPU sections on a machine without
# CUDA, where an unguarded .cuda() would raise. The section is kept as one
# cell because an ``if`` block cannot be split across cells.
if torch.cuda.is_available():
    crit = nn.CrossEntropyLoss().cuda()

    m = nn.Linear(5, 5)
    m.cuda()

    x32 = torch.rand(10, 5).float().cuda()
    x16 = torch.rand(10, 5, dtype=torch.bfloat16).cuda()
    y = torch.randint(5, size=(10,)).cuda()

    test(x32, x16, y, m, crit, 'cuda')
else:
    print('cuda not available: skipping GPU + without weights in criterion')

# %% markdown
## GPU + With weights in criterion
# %%
# Guarded so the script still reaches the CPU sections on a machine without
# CUDA, where an unguarded .cuda() would raise. The section is kept as one
# cell because an ``if`` block cannot be split across cells.
if torch.cuda.is_available():
    # The float32 weight is created on CPU; .cuda() on the criterion moves it.
    weight = torch.rand(5)
    crit = nn.CrossEntropyLoss(weight=weight).cuda()

    m = nn.Linear(5, 5)
    m.cuda()

    x32 = torch.rand(10, 5).float().cuda()
    x16 = torch.rand(10, 5, dtype=torch.bfloat16).cuda()
    y = torch.randint(5, size=(10,)).cuda()

    test(x32, x16, y, m, crit, 'cuda')
else:
    print('cuda not available: skipping GPU + with weights in criterion')

# %% markdown
## CPU + Without weights in criterion
# %%
# Unweighted criterion and a fresh 5 -> 5 linear model, both on the CPU.
crit = nn.CrossEntropyLoss()
m = nn.Linear(in_features=5, out_features=5)

# %%
# A float32 batch for the model, a bfloat16 batch passed straight to the
# criterion, and 10 integer class targets drawn from [0, 5).
x32 = torch.rand((10, 5)).float()
x16 = torch.rand((10, 5), dtype=torch.bfloat16)
y = torch.randint(0, 5, (10,))

# %%
test(x32, x16, y, m, crit, device_type='cpu')

# %% markdown
## CPU + With weights in criterion
# %%
# Per-class float32 weights stored on the criterion. This is the combination
# that, per the reported output, fails under CPU autocast.
weight = torch.rand(5)
crit = nn.CrossEntropyLoss(weight=weight)
m = nn.Linear(in_features=5, out_features=5)

# %%
# A float32 batch for the model, a bfloat16 batch passed straight to the
# criterion, and 10 integer class targets drawn from [0, 5).
x32 = torch.rand((10, 5)).float()
x16 = torch.rand((10, 5), dtype=torch.bfloat16)
y = torch.randint(0, 5, (10,))

# %%
test(x32, x16, y, m, crit, device_type='cpu')

Output:

cuda
1 tensor(1.8212, device='cuda:0', grad_fn=<NllLossBackward0>)
2 torch.bfloat16
3 tensor(1.8212, device='cuda:0', grad_fn=<NllLossBackward0>)
4 torch.float32 torch.float16 torch.float32 tensor([0, 1, 4, 0, 2, 4, 0, 1, 3, 3], device='cuda:0')
5 torch.float32
6 torch.float32
cuda
1 tensor(1.2175, device='cuda:0', grad_fn=<NllLossBackward0>)
2 expected scalar type BFloat16 but found Float
3 tensor(1.2176, device='cuda:0', grad_fn=<NllLossBackward0>)
4 torch.float32 torch.float16 torch.float32 tensor([1, 4, 3, 1, 2, 1, 1, 0, 0, 4], device='cuda:0')
5 torch.float32
6 torch.float32
cpu
1 tensor(1.5414, grad_fn=<NllLossBackward0>)
2 torch.bfloat16
3 tensor(1.5469, dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)
4 torch.float32 torch.bfloat16 torch.bfloat16 tensor([3, 2, 2, 2, 3, 0, 2, 0, 2, 2])
5 torch.bfloat16
6 torch.bfloat16
cpu
1 tensor(1.7708, grad_fn=<NllLossBackward0>)
2 expected scalar type BFloat16 but found Float
3 expected scalar type BFloat16 but found Float
4 torch.float32 torch.bfloat16 torch.float32 tensor([1, 4, 2, 0, 4, 4, 3, 2, 2, 4])
5 expected scalar type BFloat16 but found Float
6 torch.float32

Hi, could somebody take a look at this?

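In short, nn.CrossEntropyLoss with a `weight` argument raises "expected scalar type BFloat16 but found Float" when run on CPU under AMP, while the same code works on GPU.
The output hints at why. Under CUDA autocast the loss is computed in float32 even though the prediction is float16 (line 4 of the first two runs). Under CPU autocast the loss stays in bfloat16 (line 4 of the third run). As a result, the float32 weight tensor no longer matches the bfloat16 input on CPU.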

Could you check if you are still seeing the same error in the latest nightly?
If so, could you create an issue on GitHub so that the code owners can take a look at it, please?