Gradient dtype in torch.cuda.amp

Hi, I have some questions about torch.cuda.amp. I want to use mixed-precision training in my project, and I noticed that the gradients of the network parameters are still float32, which confuses me. I assumed that in mixed-precision training, fp16 parameters should have fp16 gradients and fp32 parameters should have fp32 gradients. Is my understanding wrong? Can anyone help?
The following is my test code:

import torch
import torch.nn as nn
import torch.optim as optim

class CustomModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(CustomModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

# Initialize the model
input_size = X_train_tensor.shape[1]
hidden_size = 16000
num_classes = 2000
model = CustomModel(input_size, hidden_size, num_classes).to('cuda')
scaler = torch.cuda.amp.GradScaler()
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 1
for epoch in range(num_epochs):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(dtype=torch.float16, enabled=True):
        outputs = model(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
    print(outputs.dtype)
    print(model.fc1.weight.grad.dtype)
    print(model.fc2.weight.grad.dtype)

print("Done!")

and the output is

torch.float16
torch.float32
torch.float32
Done!

@ptrblck Do you have any comments? Thank you very much!

All parameters stay in float32 when using amp. If you want float16 parameters you would need to cast them manually, and you might run the risk of divergence etc.
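For reference, a minimal way to verify this behavior (the layer and tensor names below are just illustrative, not from the original code): under autocast only the compute runs in float16, while the parameters and their accumulated .grad remain float32; manually casting a module to half does produce float16 gradients, but without amp's safeguards.

import torch
import torch.nn as nn

lin = nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device='cuda')

with torch.cuda.amp.autocast(dtype=torch.float16):
    out = lin(x)              # matmul runs in float16, so the output is float16
out.float().sum().backward()

print(lin.weight.dtype)       # torch.float32  (parameters are untouched by autocast)
print(lin.weight.grad.dtype)  # torch.float32  (grad matches the parameter dtype)

# Manually casting the module does give float16 gradients,
# but then you take on the risk of divergence mentioned above:
lin16 = nn.Linear(8, 8).half().cuda()
lin16(x.half()).sum().backward()
print(lin16.weight.grad.dtype)  # torch.float16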

I don’t understand why all parameters are in float32 when using amp. Where is the mixed precision then? Do you mean that because all parameters are in float32, the gradients are also in float32? If so, is the grad scaler still necessary? I thought the grad scaler only made sense for float16 gradients.
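The scaling is still relevant because backward() itself runs through the float16 activations produced under autocast, and those intermediate float16 gradients can underflow to zero even though the final .grad accumulated into the float32 parameters is float32. A minimal sketch of the usual flow, reusing model, criterion, optimizer and the training tensors from the code above (the printed scale value is just the typical default):

scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad()
with torch.cuda.amp.autocast(dtype=torch.float16):
    loss = criterion(model(X_train_tensor), y_train_tensor)

# Scaling the loss keeps the intermediate float16 gradients in backward()
# from underflowing to zero.
scaler.scale(loss).backward()

# At this point the float32 .grad tensors still hold the scaled gradients.
print(scaler.get_scale())     # e.g. 65536.0 at the start of training
scaler.unscale_(optimizer)    # divide .grad in place by the scale factor
scaler.step(optimizer)        # skips the optimizer step if infs/NaNs were found
scaler.update()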