Hi, I'm hitting an error when using AMP on CPU together with class weights in the criterion.
It looks like an internal bug to me, because AMP with criterion weights works fine on GPU, but fails on CPU.
Code:
```python
import torch
import torch.nn as nn

# %%
def test(x32, x16, y, m, crit, device_type):
    print(device_type)
    # Standard full-precision computation
    x_pred = m(x32)
    loss = crit(x_pred, y)
    print(1, loss)
    try:
        loss = crit(x16, y)
        print(2, loss.dtype)
    except Exception as e:
        print(2, e)
    # AMP
    with torch.autocast(device_type=device_type):
        x_pred = m(x32)
        try:
            loss = crit(x_pred, y)  # Error on CPU + criterion weights
            print(3, loss)
        except Exception as e:
            print(3, e)
        print(4, x32.dtype, x_pred.dtype, loss.dtype, y)
        try:
            loss = crit(x16, y)  # Also an error on CPU + criterion weights
            print(5, loss.dtype)
        except Exception as e:
            print(5, e)
    print(6, loss.dtype)
# %% markdown
## GPU + Without weights in criterion
# %%
crit = nn.CrossEntropyLoss().cuda()
m = nn.Linear(5,5)
m.cuda()
# %%
x32 = torch.rand(10,5).float().cuda()
x16 = torch.rand(10,5, dtype=torch.bfloat16).cuda()
y = torch.randint(5, size=(10,)).cuda()
# %%
test(x32, x16, y, m, crit, 'cuda')
# %% markdown
## GPU + With weights in criterion
# %%
weight = torch.rand(5)
crit = nn.CrossEntropyLoss(weight=weight).cuda()
m = nn.Linear(5,5)
m.cuda()
# %%
x32 = torch.rand(10,5).float().cuda()
x16 = torch.rand(10,5, dtype=torch.bfloat16).cuda()
y = torch.randint(5, size=(10,)).cuda()
# %%
test(x32, x16, y, m, crit, 'cuda')
# %% markdown
## CPU + Without weights in criterion
# %%
crit = nn.CrossEntropyLoss()
m = nn.Linear(5,5)
# %%
x32 = torch.rand(10,5).float()
x16 = torch.rand(10,5, dtype=torch.bfloat16)
y = torch.randint(5, size=(10,))
# %%
test(x32, x16, y, m, crit, 'cpu')
# %% markdown
## CPU + With weights in criterion
# %%
weight = torch.rand(5)
crit = nn.CrossEntropyLoss(weight=weight)
m = nn.Linear(5,5)
# %%
x32 = torch.rand(10,5).float()
x16 = torch.rand(10,5, dtype=torch.bfloat16)
y = torch.randint(5, size=(10,))
# %%
test(x32, x16, y, m, crit, 'cpu')
```
Output:
```
cuda
1 tensor(1.8212, device='cuda:0', grad_fn=<NllLossBackward0>)
2 torch.bfloat16
3 tensor(1.8212, device='cuda:0', grad_fn=<NllLossBackward0>)
4 torch.float32 torch.float16 torch.float32 tensor([0, 1, 4, 0, 2, 4, 0, 1, 3, 3], device='cuda:0')
5 torch.float32
6 torch.float32
cuda
1 tensor(1.2175, device='cuda:0', grad_fn=<NllLossBackward0>)
2 expected scalar type BFloat16 but found Float
3 tensor(1.2176, device='cuda:0', grad_fn=<NllLossBackward0>)
4 torch.float32 torch.float16 torch.float32 tensor([1, 4, 3, 1, 2, 1, 1, 0, 0, 4], device='cuda:0')
5 torch.float32
6 torch.float32
cpu
1 tensor(1.5414, grad_fn=<NllLossBackward0>)
2 torch.bfloat16
3 tensor(1.5469, dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)
4 torch.float32 torch.bfloat16 torch.bfloat16 tensor([3, 2, 2, 2, 3, 0, 2, 0, 2, 2])
5 torch.bfloat16
6 torch.bfloat16
cpu
1 tensor(1.7708, grad_fn=<NllLossBackward0>)
2 expected scalar type BFloat16 but found Float
3 expected scalar type BFloat16 but found Float
4 torch.float32 torch.bfloat16 torch.float32 tensor([1, 4, 2, 0, 4, 4, 3, 2, 2, 4])
5 expected scalar type BFloat16 but found Float
6 torch.float32
```
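The output also hints at what is going on: under CPU autocast the criterion itself runs in bfloat16 (step 3 prints a torch.bfloat16 loss in the unweighted CPU case), while the weight tensor stays float32, which matches the "expected scalar type BFloat16 but found Float" error. On CUDA, autocast runs the loss in float32 (step 4: loss is float32 even though x_pred is float16), so the weighted case works there. Until this is fixed, a possible workaround (just a sketch, assuming the float32 weight buffer is indeed the culprit) is to keep the weighted criterion out of the low-precision path by casting the logits back to float32:

```python
import torch
import torch.nn as nn

weight = torch.rand(5)
crit = nn.CrossEntropyLoss(weight=weight)
m = nn.Linear(5, 5)
x32 = torch.rand(10, 5)
y = torch.randint(5, size=(10,))

with torch.autocast(device_type='cpu'):
    x_pred = m(x32)  # bfloat16 under CPU autocast

# Cast logits back to float32 so they match the float32 weight buffer
loss = crit(x_pred.float(), y)
print(loss.dtype)  # torch.float32
```

Casting the criterion instead (crit.to(torch.bfloat16)) should also make the dtypes agree, at the cost of computing the weighted loss in bfloat16.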